Update ort and improve the speed of preprocessing
* Add onnx proto * Update ort to 2.0.0-rc.2 * Improve the speed of resizing * Fix yolo-seg bug * Update README.md
10
Cargo.toml
@ -12,17 +12,18 @@ exclude = ["assets/*", "examples/*"]
|
||||
[dependencies]
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
ndarray = { version = "0.15.6" }
|
||||
ort = { version = "2.0.0-alpha.4", default-features = false, features = [
|
||||
ort = { version = "2.0.0-rc.2", default-features = false, features = [
|
||||
"load-dynamic",
|
||||
"copy-dylibs",
|
||||
"profiling",
|
||||
"half",
|
||||
"ndarray",
|
||||
"cuda",
|
||||
"tensorrt",
|
||||
"coreml",
|
||||
"ureq",
|
||||
"openvino",
|
||||
"rocm",
|
||||
"openvino",
|
||||
"operator-libraries"
|
||||
] }
|
||||
anyhow = { version = "1.0.75" }
|
||||
regex = { version = "1.5.4" }
|
||||
@ -41,3 +42,6 @@ image = "0.25.1"
|
||||
imageproc = { version = "0.24" }
|
||||
ab_glyph = "0.2.23"
|
||||
geo = "0.28.0"
|
||||
prost = "0.12.4"
|
||||
human_bytes = "0.4.3"
|
||||
fast_image_resize = "3.0.4"
|
22
README.md
@ -4,9 +4,10 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp
|
||||
|
||||
## Recently Updated
|
||||
|
||||
| Portrait Matting |
|
||||
|
||||
| YOLOv8-Obb |
|
||||
| :----------------------------: |
|
||||
|<img src='examples/modnet/demo.png' width="800px">|
|
||||
|<img src='examples/yolov8/demo-obb-2.png' width="800px">|
|
||||
|
||||
|
||||
| Depth-Anything |
|
||||
@ -14,17 +15,16 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp
|
||||
|<img src='examples/depth-anything/demo.png' width="800px">|
|
||||
|
||||
|
||||
| Portrait Matting |
|
||||
| :----------------------------: |
|
||||
|<img src='examples/modnet/demo.png' width="800px">|
|
||||
|
||||
|
||||
| YOLOP-v2 | Face-Parsing | Text-Detection |
|
||||
| :----------------------------: | :------------------------------: | :------------------------------: |
|
||||
|<img src='examples/yolop/demo.png' height="180px">| <img src='examples/face-parsing/demo.png' height="180px"> | <img src='examples/db/demo.png' height="180px"> |
|
||||
|
||||
|
||||
| YOLOv8-Obb |
|
||||
| :----------------------------: |
|
||||
|<img src='examples/yolov8/demo-obb-2.png' width="800px">|
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Supported Models
|
||||
@ -94,7 +94,7 @@ check **[ort guide](https://ort.pyke.io/setup/linking)**
|
||||
|
||||
## Integrate into your own project
|
||||
<details close>
|
||||
<summary>Check Here</summary>
|
||||
<summary>Expand</summary>
|
||||
|
||||
#### 1. Add `usls` as a dependency to your project's `Cargo.toml`
|
||||
|
||||
@ -107,7 +107,7 @@ cargo add --git https://github.com/jamjamjon/usls
|
||||
```Rust
|
||||
let options = Options::default()
|
||||
.with_model("../models/yolov8m-seg-dyn-f16.onnx");
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(options)?;
|
||||
```
|
||||
|
||||
- If you want to run your model with TensorRT or CoreML
|
||||
@ -129,7 +129,7 @@ let mut model = YOLO::new(&options)?;
|
||||
|
||||
```Rust
|
||||
let options = Options::default()
|
||||
.with_confs(&[0.4, 0.15]) // person: 0.4, others: 0.15
|
||||
.with_confs(&[0.4, 0.15]) // class 0: 0.4, others: 0.15
|
||||
```
|
||||
- Go check [Options](src/options.rs) for more model options.
|
||||
|
||||
|
BIN
assets/liuyifei.png
Normal file
After Width: | Height: | Size: 340 KiB |
Before Width: | Height: | Size: 126 KiB |
@ -15,7 +15,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_profile(false);
|
||||
|
||||
// build model
|
||||
let model = Clip::new(options_visual, options_textual)?;
|
||||
let mut model = Clip::new(options_visual, options_textual)?;
|
||||
|
||||
// texts
|
||||
let texts = vec![
|
||||
|
@ -6,18 +6,18 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_i00((1, 4, 8).into())
|
||||
.with_i02((608, 960, 1280).into())
|
||||
.with_i03((608, 960, 1280).into())
|
||||
// .with_trt(0)
|
||||
.with_confs(&[0.4])
|
||||
.with_min_width(5.0)
|
||||
.with_min_height(12.0)
|
||||
// .with_trt(0)
|
||||
.with_model("ppocr-v4-db-dyn.onnx")?;
|
||||
|
||||
let mut model = DB::new(&options)?;
|
||||
let mut model = DB::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![
|
||||
DataLoader::try_read("./assets/db.png")?,
|
||||
// DataLoader::try_read("./assets/2.jpg")?,
|
||||
DataLoader::try_read("./assets/2.jpg")?,
|
||||
];
|
||||
|
||||
// run
|
||||
|
@ -7,7 +7,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_i00((1, 1, 8).into())
|
||||
.with_i02((384, 512, 1024).into())
|
||||
.with_i03((384, 512, 1024).into());
|
||||
let model = DepthAnything::new(&options)?;
|
||||
let mut model = DepthAnything::new(options)?;
|
||||
|
||||
// load
|
||||
let x = vec![DataLoader::try_read("./assets/2.jpg")?];
|
||||
|
@ -7,7 +7,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_i00((1, 1, 1).into())
|
||||
.with_i02((224, 224, 224).into())
|
||||
.with_i03((224, 224, 224).into());
|
||||
let _model = Dinov2::new(&options)?;
|
||||
let _model = Dinov2::new(options)?;
|
||||
println!("TODO...");
|
||||
|
||||
// query from vector
|
||||
|
@ -10,7 +10,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// .with_trt(0)
|
||||
// .with_fp16(true)
|
||||
.with_confs(&[0.5]);
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/nini.png")?];
|
||||
|
@ -8,7 +8,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_i02((416, 640, 800).into())
|
||||
.with_i03((416, 640, 800).into())
|
||||
.with_confs(&[0.4]);
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
|
||||
|
Before Width: | Height: | Size: 93 KiB After Width: | Height: | Size: 128 KiB |
@ -7,10 +7,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_i00((1, 1, 4).into())
|
||||
.with_i02((416, 512, 800).into())
|
||||
.with_i03((416, 512, 800).into());
|
||||
let model = MODNet::new(&options)?;
|
||||
let mut model = MODNet::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/portrait.jpg")?];
|
||||
let x = vec![DataLoader::try_read("./assets/liuyifei.png")?];
|
||||
|
||||
// run
|
||||
let y = model.run(&x)?;
|
||||
|
@ -4,9 +4,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// build model
|
||||
let options = Options::default()
|
||||
.with_model("rtdetr-l-f16.onnx")?
|
||||
.with_confs(&[0.4, 0.15]) // person: 0.4, others: 0.15
|
||||
.with_confs(&[0.4, 0.15])
|
||||
.with_names(&coco::NAMES_80);
|
||||
let mut model = RTDETR::new(&options)?;
|
||||
let mut model = RTDETR::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
|
||||
|
@ -8,7 +8,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_nk(17)
|
||||
.with_confs(&[0.3])
|
||||
.with_kconfs(&[0.5]);
|
||||
let mut model = RTMO::new(&options)?;
|
||||
let mut model = RTMO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
|
||||
|
@ -8,7 +8,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_confs(&[0.2])
|
||||
.with_vocab("ppocr_rec_vocab.txt")?
|
||||
.with_model("ppocr-v4-svtr-ch-dyn.onnx")?;
|
||||
let mut model = SVTR::new(&options)?;
|
||||
let mut model = SVTR::new(options)?;
|
||||
|
||||
// load images
|
||||
let dl = DataLoader::default()
|
||||
|
@ -7,9 +7,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_i00((1, 1, 4).into())
|
||||
.with_i02((416, 640, 800).into())
|
||||
.with_i03((416, 640, 800).into())
|
||||
.with_confs(&[0.3]) // shoes: 0.2
|
||||
.with_confs(&[0.3])
|
||||
.with_profile(false);
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
|
||||
|
@ -6,7 +6,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_model("yolopv2-dyn-480x800.onnx")?
|
||||
.with_i00((1, 1, 8).into())
|
||||
.with_confs(&[0.3]);
|
||||
let mut model = YOLOPv2::new(&options)?;
|
||||
let mut model = YOLOPv2::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/car.jpg")?];
|
||||
|
@ -10,13 +10,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_anchors_first(true)
|
||||
.with_yolo_task(YOLOTask::Segment)
|
||||
.with_model("yolov5s-seg.onnx")?
|
||||
.with_trt(0)
|
||||
.with_fp16(true)
|
||||
// .with_trt(0)
|
||||
// .with_fp16(true)
|
||||
.with_i00((1, 1, 4).into())
|
||||
.with_i02((224, 640, 800).into())
|
||||
.with_i03((224, 640, 800).into())
|
||||
.with_dry_run(3);
|
||||
let mut model = YOLO::new(&options)?;
|
||||
.with_i03((224, 640, 800).into());
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
|
||||
|
Before Width: | Height: | Size: 285 KiB After Width: | Height: | Size: 286 KiB |
@ -8,7 +8,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_i02((416, 640, 800).into())
|
||||
.with_i03((416, 640, 800).into())
|
||||
.with_confs(&[0.15]);
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/kids.jpg")?];
|
||||
|
@ -3,7 +3,7 @@ use usls::{models::YOLO, Annotator, DataLoader, Options};
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// build model
|
||||
let options = Options::default().with_model("yolov8-falldown-f16.onnx")?;
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/falldown.jpg")?];
|
||||
|
@ -3,7 +3,7 @@ use usls::{models::YOLO, Annotator, DataLoader, Options};
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// build model
|
||||
let options = Options::default().with_model("yolov8-head-f16.onnx")?;
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/kids.jpg")?];
|
||||
|
@ -5,7 +5,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let options = Options::default()
|
||||
.with_model("yolov8-plastic-bag-f16.onnx")?
|
||||
.with_names(&["trash"]);
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/trash.jpg")?];
|
||||
|
Before Width: | Height: | Size: 387 KiB After Width: | Height: | Size: 391 KiB |
@ -3,21 +3,25 @@ use usls::{coco, models::YOLO, Annotator, DataLoader, Options};
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// build model
|
||||
let options = Options::default()
|
||||
.with_model("yolov8m-dyn.onnx")?
|
||||
// .with_model("yolov8m-dyn.onnx")?
|
||||
// .with_model("yolov8m-dyn-f16.onnx")?
|
||||
// .with_model("yolov8m-pose-dyn.onnx")?
|
||||
// .with_model("yolov8m-cls-dyn.onnx")?
|
||||
// .with_model("yolov8m-seg-dyn.onnx")?
|
||||
.with_model("yolov8m-seg-dyn.onnx")?
|
||||
// .with_model("yolov8m-obb-dyn.onnx")?
|
||||
// .with_model("yolov8m-oiv7-dyn.onnx")?
|
||||
// .with_trt(0)
|
||||
// .with_fp16(true)
|
||||
// .with_coreml(0)
|
||||
// .with_cuda(3)
|
||||
.with_i00((1, 1, 4).into())
|
||||
.with_i02((224, 640, 800).into())
|
||||
.with_i03((224, 640, 800).into())
|
||||
.with_confs(&[0.4, 0.15]) // person: 0.4, others: 0.15
|
||||
.with_confs(&[0.4, 0.15]) // class 0: 0.4, others: 0.15
|
||||
.with_names2(&coco::KEYPOINTS_NAMES_17)
|
||||
// .with_dry_run(10)
|
||||
.with_profile(false);
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// build dataloader
|
||||
let dl = DataLoader::default()
|
||||
|
@ -7,8 +7,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_i00((1, 1, 4).into())
|
||||
.with_i02((416, 640, 800).into())
|
||||
.with_i03((416, 640, 800).into())
|
||||
.with_confs(&[0.4, 0.15]); // person: 0.4, others: 0.15
|
||||
let mut model = YOLO::new(&options)?;
|
||||
.with_confs(&[0.4, 0.15]);
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// load image
|
||||
let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
|
||||
|
@ -495,7 +495,7 @@ impl Annotator {
|
||||
|
||||
// keypoint
|
||||
let color = match &self.keypoints_palette {
|
||||
None => self.get_color(i + 10),
|
||||
None => self.get_color(i),
|
||||
Some(keypoints_palette) => keypoints_palette[i],
|
||||
};
|
||||
imageproc::drawing::draw_filled_circle_mut(
|
||||
|
@ -81,7 +81,7 @@ impl DataLoader {
|
||||
let n_new = paths.len();
|
||||
self.paths.append(&mut paths);
|
||||
println!(
|
||||
"{CHECK_MARK} {n_new} files found ({} total)",
|
||||
"{CHECK_MARK} Found images x{n_new} ({} total)",
|
||||
self.paths.len()
|
||||
);
|
||||
Ok(Self {
|
||||
|
@ -1,5 +1,6 @@
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum Device {
|
||||
Auto(usize), // TODO
|
||||
Cpu(usize),
|
||||
Cuda(usize),
|
||||
Trt(usize),
|
||||
|
@ -1,12 +1,25 @@
|
||||
use anyhow::Result;
|
||||
use half::f16;
|
||||
use human_bytes::human_bytes;
|
||||
use ndarray::{Array, IxDyn};
|
||||
use ort::{
|
||||
ExecutionProvider, ExecutionProviderDispatch, Session, SessionBuilder, TensorElementType,
|
||||
TensorRTExecutionProvider, ValueType,
|
||||
ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider,
|
||||
MINOR_VERSION,
|
||||
};
|
||||
use prost::Message;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::{
|
||||
home_dir, onnx, ops::make_divisible, Device, MinOptMax, Options, Ts, CHECK_MARK, CROSS_MARK,
|
||||
};
|
||||
|
||||
use crate::{home_dir, Device, MinOptMax, Options, CHECK_MARK, CROSS_MARK, SAFE_CROSS_MARK};
|
||||
/// Ort Tensor Attrs: name, data_type, dims
///
/// Describes a group of tensors (for example all inputs or all outputs of a
/// model) as three vectors that are intended to be index-aligned.
|
||||
#[derive(Debug)]
|
||||
pub struct OrtTensorAttr {
|
||||
/// Name of each tensor.
pub names: Vec<String>,
|
||||
/// Element type of each tensor.
pub dtypes: Vec<ort::TensorElementType>,
|
||||
/// Dimensions of each tensor, one inner vector per tensor.
pub dimss: Vec<Vec<isize>>,
|
||||
}
|
||||
|
||||
/// ONNXRuntime Backend
|
||||
#[derive(Debug)]
|
||||
@ -14,75 +27,57 @@ pub struct OrtEngine {
|
||||
session: Session,
|
||||
device: Device,
|
||||
inputs_minoptmax: Vec<Vec<MinOptMax>>,
|
||||
inames: Vec<String>,
|
||||
ishapes: Vec<Vec<isize>>,
|
||||
idtypes: Vec<TensorElementType>,
|
||||
onames: Vec<String>,
|
||||
oshapes: Vec<Vec<isize>>,
|
||||
odtypes: Vec<TensorElementType>,
|
||||
inputs_attrs: OrtTensorAttr,
|
||||
outputs_attrs: OrtTensorAttr,
|
||||
profile: bool,
|
||||
num_dry_run: usize,
|
||||
model_proto: onnx::ModelProto,
|
||||
params: usize,
|
||||
wbmems: usize,
|
||||
pub ts: Ts,
|
||||
}
|
||||
|
||||
impl OrtEngine {
|
||||
pub fn dry_run(&self) -> Result<()> {
|
||||
if self.num_dry_run == 0 {
|
||||
println!("{SAFE_CROSS_MARK} No dry run count specified, skipping the dry run.");
|
||||
return Ok(());
|
||||
}
|
||||
let mut xs: Vec<Array<f32, IxDyn>> = Vec::new();
|
||||
for i in self.inputs_minoptmax.iter() {
|
||||
let mut x: Vec<usize> = Vec::new();
|
||||
for i_ in i.iter() {
|
||||
x.push(i_.opt as usize);
|
||||
}
|
||||
let x: Array<f32, IxDyn> = Array::ones(x).into_dyn();
|
||||
xs.push(x);
|
||||
}
|
||||
for _ in 0..self.num_dry_run {
|
||||
self.run(xs.as_ref())?;
|
||||
}
|
||||
println!("{CHECK_MARK} Dry run x{}", self.num_dry_run);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn new(config: &Options) -> Result<Self> {
|
||||
ort::init().commit()?;
|
||||
let session = Session::builder()?.with_model_from_file(&config.onnx_path)?;
|
||||
// onnx graph
|
||||
let model_proto = Self::load_onnx(&config.onnx_path)?;
|
||||
let graph = match &model_proto.graph {
|
||||
Some(graph) => graph,
|
||||
None => anyhow::bail!("No graph found in this proto"),
|
||||
};
|
||||
|
||||
// inputs
|
||||
let mut ishapes = Vec::new();
|
||||
let mut idtypes = Vec::new();
|
||||
let mut inames = Vec::new();
|
||||
for x in session.inputs.iter() {
|
||||
inames.push(x.name.to_owned());
|
||||
if let ValueType::Tensor { ty, dimensions } = &x.input_type {
|
||||
ishapes.push(dimensions.iter().map(|x| *x as isize).collect::<Vec<_>>());
|
||||
idtypes.push(*ty);
|
||||
} else {
|
||||
ishapes.push(vec![-1_isize]);
|
||||
idtypes.push(ort::TensorElementType::Float32);
|
||||
}
|
||||
}
|
||||
// outputs
|
||||
let mut oshapes = Vec::new();
|
||||
let mut odtypes = Vec::new();
|
||||
let mut onames = Vec::new();
|
||||
for x in session.outputs.iter() {
|
||||
onames.push(x.name.to_owned());
|
||||
if let ValueType::Tensor { ty, dimensions } = &x.output_type {
|
||||
oshapes.push(dimensions.iter().map(|x| *x as isize).collect::<Vec<_>>());
|
||||
odtypes.push(*ty);
|
||||
} else {
|
||||
oshapes.push(vec![-1_isize]);
|
||||
odtypes.push(ort::TensorElementType::Float32);
|
||||
}
|
||||
// model params & mems
|
||||
let byte_alignment = 16; // 16 for simd; 8 for most
|
||||
let mut params: usize = 0;
|
||||
let mut wbmems: usize = 0;
|
||||
let mut initializer_names: HashSet<&str> = HashSet::new();
|
||||
for tensor_proto in graph.initializer.iter() {
|
||||
initializer_names.insert(&tensor_proto.name);
|
||||
let param = tensor_proto.dims.iter().product::<i64>() as usize;
|
||||
params += param;
|
||||
|
||||
// mems
|
||||
let param = make_divisible(param, byte_alignment);
|
||||
let n = Self::nbytes_from_onnx_dtype_id(tensor_proto.data_type as usize);
|
||||
let wbmem = param * n;
|
||||
wbmems += wbmem;
|
||||
}
|
||||
|
||||
// inputs & outputs
|
||||
let inputs_attrs = Self::io_from_onnx_value_info(&initializer_names, &graph.input)?;
|
||||
let outputs_attrs = Self::io_from_onnx_value_info(&initializer_names, &graph.output)?;
|
||||
|
||||
// inputs minoptmax
|
||||
let mut inputs_minoptmax: Vec<Vec<MinOptMax>> = Vec::new();
|
||||
for (i, dims) in ishapes.iter().enumerate() {
|
||||
for (i, dims) in inputs_attrs.dimss.iter().enumerate() {
|
||||
let mut v_: Vec<MinOptMax> = Vec::new();
|
||||
for (ii, &x) in dims.iter().enumerate() {
|
||||
let x_default: MinOptMax = (ishapes[i][ii], ishapes[i][ii], ishapes[i][ii]).into();
|
||||
let x_default: MinOptMax = (
|
||||
inputs_attrs.dimss[i][ii],
|
||||
inputs_attrs.dimss[i][ii],
|
||||
inputs_attrs.dimss[i][ii],
|
||||
)
|
||||
.into();
|
||||
let x: MinOptMax = match (i, ii) {
|
||||
(0, 0) => Self::_set_ixx(x, &config.i00, i, ii).unwrap_or(x_default),
|
||||
(0, 1) => Self::_set_ixx(x, &config.i01, i, ii).unwrap_or(x_default),
|
||||
@ -115,72 +110,80 @@ impl OrtEngine {
|
||||
inputs_minoptmax.push(v_);
|
||||
}
|
||||
|
||||
// build again
|
||||
// build
|
||||
ort::init().commit()?;
|
||||
let builder = Session::builder()?;
|
||||
let device = config.device.to_owned();
|
||||
let _ep = match device {
|
||||
Device::Trt(device_id) => Self::build_trt(
|
||||
&inames,
|
||||
&inputs_minoptmax,
|
||||
&builder,
|
||||
device_id,
|
||||
config.trt_int8_enable,
|
||||
config.trt_fp16_enable,
|
||||
config.trt_engine_cache_enable,
|
||||
)?,
|
||||
Device::Cuda(device_id) => Self::build_cuda(&builder, device_id)?,
|
||||
Device::CoreML(_) => {
|
||||
let coreml = ort::CoreMLExecutionProvider::default()
|
||||
.with_subgraphs()
|
||||
// .with_ane_only()
|
||||
.build();
|
||||
if coreml.is_available()? && coreml.register(&builder).is_ok() {
|
||||
println!("{CHECK_MARK} Using CoreML");
|
||||
coreml
|
||||
} else {
|
||||
println!("{CROSS_MARK} CoreML initialization failed");
|
||||
println!("{CHECK_MARK} Using CPU");
|
||||
ort::CPUExecutionProvider::default().build()
|
||||
}
|
||||
let mut device = config.device.to_owned();
|
||||
match device {
|
||||
Device::Trt(device_id) => {
|
||||
Self::build_trt(
|
||||
&inputs_attrs.names,
|
||||
&inputs_minoptmax,
|
||||
&builder,
|
||||
device_id,
|
||||
config.trt_int8_enable,
|
||||
config.trt_fp16_enable,
|
||||
config.trt_engine_cache_enable,
|
||||
)?;
|
||||
}
|
||||
Device::Cuda(device_id) => {
|
||||
Self::build_cuda(&builder, device_id).unwrap_or_else(|err| {
|
||||
device = Device::Cpu(0);
|
||||
println!("{err}");
|
||||
})
|
||||
}
|
||||
Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| {
|
||||
device = Device::Cpu(0);
|
||||
println!("{err}");
|
||||
}),
|
||||
Device::Cpu(_) => {
|
||||
println!("{CHECK_MARK} Using CPU");
|
||||
ort::CPUExecutionProvider::default().build()
|
||||
} // _ => todo!(),
|
||||
};
|
||||
Self::build_cpu(&builder)?;
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
|
||||
let session = builder
|
||||
.with_optimization_level(ort::GraphOptimizationLevel::Level3)?
|
||||
.with_model_from_file(&config.onnx_path)?;
|
||||
.commit_from_file(&config.onnx_path)?;
|
||||
|
||||
// summary
|
||||
println!(
|
||||
"{CHECK_MARK} ORT: 1.{MINOR_VERSION}.x | Opset: {} | EP: {:?} | Dtype: {:?} | Parameters: {}",
|
||||
model_proto.opset_import[0].version,
|
||||
device,
|
||||
inputs_attrs.dtypes,
|
||||
human_bytes(params as f64),
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
session,
|
||||
device,
|
||||
inputs_minoptmax,
|
||||
inames,
|
||||
ishapes,
|
||||
idtypes,
|
||||
onames,
|
||||
oshapes,
|
||||
odtypes,
|
||||
inputs_attrs,
|
||||
outputs_attrs,
|
||||
profile: config.profile,
|
||||
num_dry_run: config.num_dry_run,
|
||||
model_proto,
|
||||
params,
|
||||
wbmems,
|
||||
ts: Ts::default(),
|
||||
})
|
||||
}
|
||||
|
||||
fn build_trt(
|
||||
inames: &[String],
|
||||
names: &[String],
|
||||
inputs_minoptmax: &[Vec<MinOptMax>],
|
||||
builder: &SessionBuilder,
|
||||
device_id: usize,
|
||||
int8_enable: bool,
|
||||
fp16_enable: bool,
|
||||
engine_cache_enable: bool,
|
||||
) -> Result<ExecutionProviderDispatch> {
|
||||
) -> Result<()> {
|
||||
// auto generate shapes
|
||||
let mut spec_min = String::new();
|
||||
let mut spec_opt = String::new();
|
||||
let mut spec_max = String::new();
|
||||
for (i, name) in inames.iter().enumerate() {
|
||||
for (i, name) in names.iter().enumerate() {
|
||||
if i != 0 {
|
||||
spec_min.push(',');
|
||||
spec_opt.push(',');
|
||||
@ -217,81 +220,137 @@ impl OrtEngine {
|
||||
.with_timing_cache(false)
|
||||
.with_profile_min_shapes(spec_min)
|
||||
.with_profile_opt_shapes(spec_opt)
|
||||
.with_profile_max_shapes(spec_max)
|
||||
.build();
|
||||
.with_profile_max_shapes(spec_max);
|
||||
if trt.is_available()? && trt.register(builder).is_ok() {
|
||||
println!(
|
||||
"{CHECK_MARK} Using TensorRT (Initial model serialization may require a wait)"
|
||||
);
|
||||
Ok(trt)
|
||||
println!("\n🐢 Initial model serialization with TensorRT may require a wait...\n");
|
||||
Ok(())
|
||||
} else {
|
||||
println!("{CROSS_MARK} TensorRT initialization failed. Try CUDA...");
|
||||
Self::build_cuda(builder, device_id)
|
||||
anyhow::bail!("{CROSS_MARK} TensorRT initialization failed")
|
||||
}
|
||||
}
|
||||
|
||||
fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<ExecutionProviderDispatch> {
|
||||
let cuda = ort::CUDAExecutionProvider::default()
|
||||
.with_device_id(device_id as i32)
|
||||
.build();
|
||||
if cuda.is_available()? && cuda.register(builder).is_ok() {
|
||||
println!("{CHECK_MARK} Using CUDA");
|
||||
Ok(cuda)
|
||||
fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> {
|
||||
let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32);
|
||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
println!("{CROSS_MARK} CUDA initialization failed");
|
||||
println!("{CHECK_MARK} Using CPU");
|
||||
Ok(ort::CPUExecutionProvider::default().build())
|
||||
anyhow::bail!("{CROSS_MARK} CUDA initialization failed")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&self, xs: &[Array<f32, IxDyn>]) -> Result<Vec<Array<f32, IxDyn>>> {
|
||||
// input
|
||||
fn build_coreml(builder: &SessionBuilder) -> Result<()> {
|
||||
let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
|
||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
anyhow::bail!("{CROSS_MARK} CoreML initialization failed")
|
||||
}
|
||||
}
|
||||
|
||||
fn build_cpu(builder: &SessionBuilder) -> Result<()> {
|
||||
let ep = ort::CUDAExecutionProvider::default();
|
||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
anyhow::bail!("{CROSS_MARK} CPU initialization failed")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dry_run(&mut self) -> Result<()> {
|
||||
if self.num_dry_run > 0 {
|
||||
let mut xs: Vec<Array<f32, IxDyn>> = Vec::new();
|
||||
for i in self.inputs_minoptmax.iter() {
|
||||
let mut x: Vec<usize> = Vec::new();
|
||||
for i_ in i.iter() {
|
||||
x.push(i_.opt as usize);
|
||||
}
|
||||
let x: Array<f32, IxDyn> = Array::ones(x).into_dyn();
|
||||
xs.push(x);
|
||||
}
|
||||
for _ in 0..self.num_dry_run {
|
||||
self.run(xs.as_ref())?;
|
||||
}
|
||||
self.ts.clear();
|
||||
println!("{CHECK_MARK} Dryrun x{}", self.num_dry_run);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn run(&mut self, xs: &[Array<f32, IxDyn>]) -> Result<Vec<Array<f32, IxDyn>>> {
|
||||
// inputs dtype alignment
|
||||
let mut xs_ = Vec::new();
|
||||
let t_pre = std::time::Instant::now();
|
||||
for (idtype, x) in self.idtypes.iter().zip(xs.iter()) {
|
||||
let x_ = match idtype {
|
||||
TensorElementType::Float32 => ort::Value::from_array(x.view())?,
|
||||
TensorElementType::Float16 => ort::Value::from_array(x.mapv(f16::from_f32).view())?,
|
||||
TensorElementType::Int32 => ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?,
|
||||
TensorElementType::Int64 => ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?,
|
||||
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.iter()) {
|
||||
let x_ = match &idtype {
|
||||
TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
|
||||
TensorElementType::Float16 => {
|
||||
ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Int32 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Int64 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
|
||||
}
|
||||
_ => todo!(),
|
||||
};
|
||||
xs_.push(x_);
|
||||
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
|
||||
}
|
||||
let t_pre = t_pre.elapsed();
|
||||
self.ts.add_or_push(0, t_pre);
|
||||
|
||||
// inference
|
||||
let t_run = std::time::Instant::now();
|
||||
let ys = self.session.run(xs_.as_ref())?;
|
||||
let outputs = self.session.run(&xs_[..])?;
|
||||
let t_run = t_run.elapsed();
|
||||
self.ts.add_or_push(1, t_run);
|
||||
|
||||
// output
|
||||
let mut ys_ = Vec::new();
|
||||
let mut ys = Vec::new();
|
||||
let t_post = std::time::Instant::now();
|
||||
|
||||
for (dtype, name) in self.odtypes.iter().zip(self.onames.iter()) {
|
||||
let y = &ys[name.as_str()];
|
||||
for (dtype, name) in self
|
||||
.outputs_attrs
|
||||
.dtypes
|
||||
.iter()
|
||||
.zip(self.outputs_attrs.names.iter())
|
||||
{
|
||||
let y = &outputs[name.as_str()];
|
||||
let y_ = match &dtype {
|
||||
TensorElementType::Float32 => y.extract_tensor::<f32>()?.view().to_owned(),
|
||||
TensorElementType::Float16 => y.extract_tensor::<f16>()?.view().mapv(f16::to_f32),
|
||||
TensorElementType::Float32 => y.try_extract_tensor::<f32>()?.view().into_owned(),
|
||||
TensorElementType::Float16 => y
|
||||
.try_extract_tensor::<f16>()?
|
||||
.view()
|
||||
.mapv(f16::to_f32)
|
||||
.into_owned(),
|
||||
TensorElementType::Int64 => y
|
||||
.extract_tensor::<i64>()?
|
||||
.try_extract_tensor::<i64>()?
|
||||
.view()
|
||||
.to_owned()
|
||||
.mapv(|x| x as f32),
|
||||
.mapv(|x| x as f32)
|
||||
.into_owned(),
|
||||
_ => todo!(),
|
||||
};
|
||||
ys_.push(y_);
|
||||
ys.push(y_);
|
||||
}
|
||||
let t_post = t_post.elapsed();
|
||||
self.ts.add_or_push(2, t_post);
|
||||
|
||||
if self.profile {
|
||||
let len = 10usize;
|
||||
let n = 4usize;
|
||||
println!(
|
||||
"[Profile] batch: {:?} => {:.4?} (i: {t_pre:.4?}, run: {t_run:.4?}, o: {t_post:.4?})",
|
||||
self.batch().opt,
|
||||
t_pre + t_run + t_post
|
||||
"[Profile] {:>len$.n$?} ({:>len$.n$?} avg) [alignment: {:>len$.n$?} ({:>len$.n$?} avg) | inference: {:>len$.n$?} ({:>len$.n$?} avg) | to_f32: {:>len$.n$?} ({:>len$.n$?} avg)]",
|
||||
t_pre + t_run + t_post,
|
||||
self.ts.avg(),
|
||||
t_pre,
|
||||
self.ts.avgi(0),
|
||||
t_run,
|
||||
self.ts.avgi(1),
|
||||
t_post,
|
||||
self.ts.avgi(2),
|
||||
);
|
||||
}
|
||||
Ok(ys_)
|
||||
Ok(ys)
|
||||
}
|
||||
|
||||
pub fn _set_ixx(x: isize, ixx: &Option<MinOptMax>, i: usize, ii: usize) -> Option<MinOptMax> {
|
||||
@ -310,28 +369,200 @@ impl OrtEngine {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the size in bytes of one element of the ONNX tensor data type
/// identified by `x`.
///
/// `x` is the numeric value of the ONNX `TensorProto.DataType` enum. Besides
/// the integer, float and bool types, the complex and 8-bit float types are
/// covered so that models containing such initializers do not abort.
///
/// # Panics
/// Panics for ids without a per-element byte size: `0` (undefined), the
/// packed 4-bit types, and any id not listed below.
pub fn nbytes_from_onnx_dtype_id(x: usize) -> usize {
    match x {
        15 => 16,                 // complex128
        7 | 11 | 13 | 14 => 8,    // i64, f64, u64, complex64
        1 | 6 | 12 => 4,          // f32, i32, u32
        10 | 16 | 5 | 4 => 2,     // f16, bf16, i16, u16
        2 | 3 | 9 | 17..=20 => 1, // u8, i8, bool, float8 variants
        8 => 4,                   // string, counted as 4 bytes (noted as 1~4 originally)
        _ => unimplemented!("no element size is defined for ONNX dtype id {}", x),
    }
}
|
||||
|
||||
pub fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize {
|
||||
match x {
|
||||
ort::TensorElementType::Float64
|
||||
| ort::TensorElementType::Uint64
|
||||
| ort::TensorElementType::Int64 => 8, // i64, f64, u64
|
||||
ort::TensorElementType::Float32
|
||||
| ort::TensorElementType::Uint32
|
||||
| ort::TensorElementType::Int32
|
||||
| ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
|
||||
ort::TensorElementType::Float16
|
||||
| ort::TensorElementType::Bfloat16
|
||||
| ort::TensorElementType::Int16
|
||||
| ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
|
||||
ort::TensorElementType::Uint8
|
||||
| ort::TensorElementType::Int8
|
||||
| ort::TensorElementType::Bool => 1, // u8, i8, bool
|
||||
}
|
||||
}
|
||||
|
||||
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::TensorElementType> {
|
||||
match value {
|
||||
0 => None,
|
||||
1 => Some(ort::TensorElementType::Float32),
|
||||
2 => Some(ort::TensorElementType::Uint8),
|
||||
3 => Some(ort::TensorElementType::Int8),
|
||||
4 => Some(ort::TensorElementType::Uint16),
|
||||
5 => Some(ort::TensorElementType::Int16),
|
||||
6 => Some(ort::TensorElementType::Int32),
|
||||
7 => Some(ort::TensorElementType::Int64),
|
||||
8 => Some(ort::TensorElementType::String),
|
||||
9 => Some(ort::TensorElementType::Bool),
|
||||
10 => Some(ort::TensorElementType::Float16),
|
||||
11 => Some(ort::TensorElementType::Float64),
|
||||
12 => Some(ort::TensorElementType::Uint32),
|
||||
13 => Some(ort::TensorElementType::Uint64),
|
||||
14 => None, // COMPLEX64
|
||||
15 => None, // COMPLEX128
|
||||
16 => Some(ort::TensorElementType::Bfloat16),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn i_from_session(session: &ort::Session) -> Result<OrtTensorAttr> {
|
||||
let mut dimss = Vec::new();
|
||||
let mut dtypes = Vec::new();
|
||||
let mut names = Vec::new();
|
||||
for x in session.inputs.iter() {
|
||||
names.push(x.name.to_owned());
|
||||
if let ort::ValueType::Tensor { ty, dimensions } = &x.input_type {
|
||||
dimss.push(dimensions.iter().map(|x| *x as isize).collect::<Vec<_>>());
|
||||
dtypes.push(*ty);
|
||||
} else {
|
||||
dimss.push(vec![-1_isize]);
|
||||
dtypes.push(ort::TensorElementType::Float32);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(OrtTensorAttr {
|
||||
names,
|
||||
dimss,
|
||||
dtypes,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn o_from_session(session: &ort::Session) -> Result<OrtTensorAttr> {
|
||||
let mut dimss = Vec::new();
|
||||
let mut dtypes = Vec::new();
|
||||
let mut names = Vec::new();
|
||||
for x in session.outputs.iter() {
|
||||
names.push(x.name.to_owned());
|
||||
if let ort::ValueType::Tensor { ty, dimensions } = &x.output_type {
|
||||
dimss.push(dimensions.iter().map(|x| *x as isize).collect::<Vec<_>>());
|
||||
dtypes.push(*ty);
|
||||
} else {
|
||||
dimss.push(vec![-1_isize]);
|
||||
dtypes.push(ort::TensorElementType::Float32);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(OrtTensorAttr {
|
||||
names,
|
||||
dimss,
|
||||
dtypes,
|
||||
})
|
||||
}
|
||||
|
||||
fn io_from_onnx_value_info(
|
||||
initializer_names: &HashSet<&str>,
|
||||
value_info: &[onnx::ValueInfoProto],
|
||||
) -> Result<OrtTensorAttr> {
|
||||
let mut dimss: Vec<Vec<isize>> = Vec::new();
|
||||
let mut dtypes: Vec<ort::TensorElementType> = Vec::new();
|
||||
let mut names: Vec<String> = Vec::new();
|
||||
for v in value_info.iter() {
|
||||
if initializer_names.contains(v.name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
names.push(v.name.to_string());
|
||||
let dtype = match &v.r#type {
|
||||
Some(dtype) => dtype,
|
||||
None => continue,
|
||||
};
|
||||
let dtype = match &dtype.value {
|
||||
Some(dtype) => dtype,
|
||||
None => continue,
|
||||
};
|
||||
let tensor = match dtype {
|
||||
onnx::type_proto::Value::TensorType(tensor) => tensor,
|
||||
_ => continue,
|
||||
};
|
||||
let tensor_type = tensor.elem_type;
|
||||
let tensor_type = match Self::ort_dtype_from_onnx_dtype_id(tensor_type) {
|
||||
Some(dtype) => dtype,
|
||||
None => continue,
|
||||
// None => anyhow::bail!("DType not supported"),
|
||||
};
|
||||
dtypes.push(tensor_type);
|
||||
|
||||
let shapes = match &tensor.shape {
|
||||
Some(shapes) => shapes,
|
||||
None => continue,
|
||||
// None => anyhow::bail!("DType has no shapes"),
|
||||
};
|
||||
let mut shape_: Vec<isize> = Vec::new();
|
||||
for shape in shapes.dim.iter() {
|
||||
match &shape.value {
|
||||
None => continue,
|
||||
Some(value) => match value {
|
||||
onnx::tensor_shape_proto::dimension::Value::DimValue(x) => {
|
||||
shape_.push(*x as isize);
|
||||
}
|
||||
onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
|
||||
shape_.push(-1isize);
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
dimss.push(shape_);
|
||||
}
|
||||
Ok(OrtTensorAttr {
|
||||
dimss,
|
||||
dtypes,
|
||||
names,
|
||||
})
|
||||
}
|
||||
|
||||
/// Reads the file at `p` into memory and decodes it as an ONNX `ModelProto`.
///
/// # Errors
/// Fails when the file cannot be read or the bytes do not decode.
pub fn load_onnx<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
    let bytes = std::fs::read(p)?;
    let model = onnx::ModelProto::decode(bytes.as_slice())?;
    Ok(model)
}
|
||||
|
||||
pub fn oshapes(&self) -> &Vec<Vec<isize>> {
|
||||
&self.oshapes
|
||||
&self.outputs_attrs.dimss
|
||||
}
|
||||
|
||||
pub fn odimss(&self) -> &Vec<Vec<isize>> {
|
||||
&self.outputs_attrs.dimss
|
||||
}
|
||||
|
||||
pub fn onames(&self) -> &Vec<String> {
|
||||
&self.onames
|
||||
&self.outputs_attrs.names
|
||||
}
|
||||
|
||||
/// Returns the element type of every output tensor, read from `outputs_attrs`.
pub fn odtypes(&self) -> &Vec<ort::TensorElementType> {
    &self.outputs_attrs.dtypes
}
|
||||
|
||||
pub fn ishapes(&self) -> &Vec<Vec<isize>> {
|
||||
&self.ishapes
|
||||
&self.inputs_attrs.dimss
|
||||
}
|
||||
|
||||
pub fn idimss(&self) -> &Vec<Vec<isize>> {
|
||||
&self.inputs_attrs.dimss
|
||||
}
|
||||
|
||||
pub fn inames(&self) -> &Vec<String> {
|
||||
&self.inames
|
||||
&self.inputs_attrs.names
|
||||
}
|
||||
|
||||
/// Returns the element type of every input tensor, read from `inputs_attrs`.
pub fn idtypes(&self) -> &Vec<ort::TensorElementType> {
    &self.inputs_attrs.dtypes
}
|
||||
|
||||
pub fn device(&self) -> &Device {
|
||||
@ -355,7 +586,7 @@ impl OrtEngine {
|
||||
}
|
||||
|
||||
pub fn is_batch_dyn(&self) -> bool {
|
||||
self.ishapes[0][0] == -1
|
||||
self.ishapes()[0][0] == -1
|
||||
}
|
||||
|
||||
pub fn try_fetch(&self, key: &str) -> Option<String> {
|
||||
@ -372,7 +603,31 @@ impl OrtEngine {
|
||||
&self.session
|
||||
}
|
||||
|
||||
pub fn version(&self) -> Option<String> {
|
||||
self.try_fetch("version")
|
||||
pub fn ir_version(&self) -> usize {
|
||||
self.model_proto.ir_version as usize
|
||||
}
|
||||
|
||||
pub fn opset_version(&self) -> usize {
|
||||
self.model_proto.opset_import[0].version as usize
|
||||
}
|
||||
|
||||
pub fn producer_name(&self) -> String {
|
||||
self.model_proto.producer_name.to_string()
|
||||
}
|
||||
|
||||
pub fn producer_version(&self) -> String {
|
||||
self.model_proto.producer_version.to_string()
|
||||
}
|
||||
|
||||
pub fn model_version(&self) -> usize {
|
||||
self.model_proto.model_version as usize
|
||||
}
|
||||
|
||||
pub fn parameters(&self) -> usize {
|
||||
self.params
|
||||
}
|
||||
|
||||
pub fn memory_weights(&self) -> usize {
|
||||
self.wbmems
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
/// A value composed of Min-Opt-Max
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone)]
|
||||
pub struct MinOptMax {
|
||||
pub min: isize,
|
||||
pub opt: isize,
|
||||
@ -16,6 +16,16 @@ impl Default for MinOptMax {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MinOptMax {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("")
|
||||
.field("Min", &self.min)
|
||||
.field("Opt", &self.opt)
|
||||
.field("Max", &self.max)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(isize, isize, isize)> for MinOptMax {
|
||||
fn from((min, opt, max): (isize, isize, isize)) -> Self {
|
||||
let min = min.min(opt);
|
||||
|
@ -6,9 +6,11 @@ mod engine;
|
||||
mod logits_sampler;
|
||||
mod metric;
|
||||
mod min_opt_max;
|
||||
pub mod onnx;
|
||||
pub mod ops;
|
||||
mod options;
|
||||
mod tokenizer_stream;
|
||||
mod ts;
|
||||
|
||||
pub use annotator::Annotator;
|
||||
pub use dataloader::DataLoader;
|
||||
@ -20,3 +22,4 @@ pub use metric::Metric;
|
||||
pub use min_opt_max::MinOptMax;
|
||||
pub use options::Options;
|
||||
pub use tokenizer_stream::TokenizerStream;
|
||||
pub use ts::Ts;
|
||||
|
1061
src/core/onnx.rs
Normal file
195
src/core/ops.rs
@ -1,6 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use fast_image_resize as fr;
|
||||
use image::{DynamicImage, GenericImageView, ImageBuffer};
|
||||
use ndarray::{Array, Axis, IxDyn};
|
||||
use ndarray::{s, Array, Axis, IxDyn};
|
||||
|
||||
pub fn standardize(xs: Array<f32, IxDyn>, mean: &[f32], std: &[f32]) -> Array<f32, IxDyn> {
|
||||
let mean = Array::from_shape_vec((1, mean.len(), 1, 1), mean.to_vec()).unwrap();
|
||||
@ -26,18 +27,57 @@ pub fn scale_wh(w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
|
||||
(r, (w0 * r).round(), (h0 * r).round())
|
||||
}
|
||||
|
||||
pub fn resize(xs: &[DynamicImage], height: u32, width: u32) -> Result<Array<f32, IxDyn>> {
|
||||
/// Creates a convolution-based `fast_image_resize` resizer for the named filter.
///
/// Accepted names (case-sensitive): `"box"`, `"bilinear"`, `"hamming"`,
/// `"catmullRom"`, `"mitchell"`, `"lanczos3"`.
///
/// # Panics
/// Panics when `ty` is not one of the accepted names.
pub fn build_resizer(ty: &str) -> fr::Resizer {
    let filter = match ty {
        "box" => fr::FilterType::Box,
        "bilinear" => fr::FilterType::Bilinear,
        "hamming" => fr::FilterType::Hamming,
        "catmullRom" => fr::FilterType::CatmullRom,
        "mitchell" => fr::FilterType::Mitchell,
        "lanczos3" => fr::FilterType::Lanczos3,
        // An unknown name is a caller error, not an unfinished feature, so
        // say what was wrong instead of "not yet implemented".
        other => panic!(
            "Unsupported resize filter: {other:?}. Expected one of: \
             box, bilinear, hamming, catmullRom, mitchell, lanczos3"
        ),
    };
    fr::Resizer::new(fr::ResizeAlg::Convolution(filter))
}
|
||||
|
||||
pub fn resize(
|
||||
xs: &[DynamicImage],
|
||||
height: u32,
|
||||
width: u32,
|
||||
filter: &str,
|
||||
) -> Result<Array<f32, IxDyn>> {
|
||||
let mut ys = Array::ones((xs.len(), 3, height as usize, width as usize)).into_dyn();
|
||||
let mut resizer = build_resizer(filter);
|
||||
for (idx, x) in xs.iter().enumerate() {
|
||||
let img = x.resize_exact(width, height, image::imageops::FilterType::Triangle);
|
||||
for (x, y, rgb) in img.pixels() {
|
||||
let x = x as usize;
|
||||
let y = y as usize;
|
||||
let [r, g, b, _] = rgb.0;
|
||||
ys[[idx, 0, y, x]] = r as f32;
|
||||
ys[[idx, 1, y, x]] = g as f32;
|
||||
ys[[idx, 2, y, x]] = b as f32;
|
||||
}
|
||||
// src
|
||||
let src_image = fr::Image::from_vec_u8(
|
||||
std::num::NonZeroU32::new(x.width()).unwrap(),
|
||||
std::num::NonZeroU32::new(x.height()).unwrap(),
|
||||
x.to_rgb8().into_raw(),
|
||||
fr::PixelType::U8x3,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// dst
|
||||
let mut dst_image = fr::Image::new(
|
||||
std::num::NonZeroU32::new(width).unwrap(),
|
||||
std::num::NonZeroU32::new(height).unwrap(),
|
||||
src_image.pixel_type(),
|
||||
);
|
||||
|
||||
// resize
|
||||
resizer
|
||||
.resize(&src_image.view(), &mut dst_image.view_mut())
|
||||
.unwrap();
|
||||
let buffer = dst_image.into_vec();
|
||||
|
||||
// to ndarray
|
||||
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)
|
||||
.unwrap()
|
||||
.mapv(|x| x as f32)
|
||||
.permuted_axes([2, 0, 1]);
|
||||
let mut data = ys.slice_mut(s![idx, .., .., ..]);
|
||||
data.assign(&y_);
|
||||
}
|
||||
Ok(ys)
|
||||
}
|
||||
@ -46,27 +86,62 @@ pub fn letterbox(
|
||||
xs: &[DynamicImage],
|
||||
height: u32,
|
||||
width: u32,
|
||||
bg: f32,
|
||||
filter: &str,
|
||||
bg: Option<u8>,
|
||||
) -> Result<Array<f32, IxDyn>> {
|
||||
// TODO: refactor
|
||||
let mut ys = Array::ones((xs.len(), 3, height as usize, width as usize)).into_dyn();
|
||||
ys.fill(bg);
|
||||
let mut resizer = build_resizer(filter);
|
||||
for (idx, x) in xs.iter().enumerate() {
|
||||
let (w0, h0) = x.dimensions();
|
||||
let (_, w_new, h_new) = scale_wh(w0 as f32, h0 as f32, width as f32, height as f32);
|
||||
let img = x.resize_exact(
|
||||
w_new as u32,
|
||||
h_new as u32,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
for (x, y, rgb) in img.pixels() {
|
||||
let x = x as usize;
|
||||
let y = y as usize;
|
||||
let [r, g, b, _] = rgb.0;
|
||||
ys[[idx, 0, y, x]] = r as f32;
|
||||
ys[[idx, 1, y, x]] = g as f32;
|
||||
ys[[idx, 2, y, x]] = b as f32;
|
||||
}
|
||||
|
||||
// src
|
||||
let src_image = fr::Image::from_vec_u8(
|
||||
std::num::NonZeroU32::new(w0).unwrap(),
|
||||
std::num::NonZeroU32::new(h0).unwrap(),
|
||||
x.to_rgb8().into_raw(),
|
||||
fr::PixelType::U8x3,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// dst
|
||||
let mut dst_image = match bg {
|
||||
Some(bg) => fr::Image::from_vec_u8(
|
||||
std::num::NonZeroU32::new(width).unwrap(),
|
||||
std::num::NonZeroU32::new(height).unwrap(),
|
||||
vec![bg; 3 * height as usize * width as usize],
|
||||
src_image.pixel_type(),
|
||||
)
|
||||
.unwrap(),
|
||||
None => fr::Image::new(
|
||||
std::num::NonZeroU32::new(width).unwrap(),
|
||||
std::num::NonZeroU32::new(height).unwrap(),
|
||||
src_image.pixel_type(),
|
||||
),
|
||||
};
|
||||
|
||||
// mutable view
|
||||
let mut dst_view = dst_image
|
||||
.view_mut()
|
||||
.crop(
|
||||
0,
|
||||
0,
|
||||
std::num::NonZeroU32::new(w_new as u32).unwrap(),
|
||||
std::num::NonZeroU32::new(h_new as u32).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// resize
|
||||
resizer.resize(&src_image.view(), &mut dst_view).unwrap();
|
||||
let buffer = dst_image.into_vec();
|
||||
|
||||
// to ndarray
|
||||
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)
|
||||
.unwrap()
|
||||
.mapv(|x| x as f32)
|
||||
.permuted_axes([2, 0, 1]);
|
||||
let mut data = ys.slice_mut(s![idx, .., .., ..]);
|
||||
data.assign(&y_);
|
||||
}
|
||||
Ok(ys)
|
||||
}
|
||||
@ -75,23 +150,63 @@ pub fn resize_with_fixed_height(
|
||||
xs: &[DynamicImage],
|
||||
height: u32,
|
||||
width: u32,
|
||||
bg: f32,
|
||||
filter: &str,
|
||||
bg: Option<u8>,
|
||||
) -> Result<Array<f32, IxDyn>> {
|
||||
let mut ys = Array::ones((xs.len(), 3, height as usize, width as usize)).into_dyn();
|
||||
ys.fill(bg);
|
||||
let mut resizer = build_resizer(filter);
|
||||
for (idx, x) in xs.iter().enumerate() {
|
||||
let (w0, h0) = x.dimensions();
|
||||
let h_new = height;
|
||||
let w_new = height * w0 / h0;
|
||||
let img = x.resize_exact(w_new, h_new, image::imageops::FilterType::CatmullRom);
|
||||
for (x, y, rgb) in img.pixels() {
|
||||
let x = x as usize;
|
||||
let y = y as usize;
|
||||
let [r, g, b, _] = rgb.0;
|
||||
ys[[idx, 0, y, x]] = r as f32;
|
||||
ys[[idx, 1, y, x]] = g as f32;
|
||||
ys[[idx, 2, y, x]] = b as f32;
|
||||
}
|
||||
|
||||
// src
|
||||
let src_image = fr::Image::from_vec_u8(
|
||||
std::num::NonZeroU32::new(w0).unwrap(),
|
||||
std::num::NonZeroU32::new(h0).unwrap(),
|
||||
x.to_rgb8().into_raw(),
|
||||
fr::PixelType::U8x3,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// dst
|
||||
let mut dst_image = match bg {
|
||||
Some(bg) => fr::Image::from_vec_u8(
|
||||
std::num::NonZeroU32::new(width).unwrap(),
|
||||
std::num::NonZeroU32::new(height).unwrap(),
|
||||
vec![bg; 3 * height as usize * width as usize],
|
||||
src_image.pixel_type(),
|
||||
)
|
||||
.unwrap(),
|
||||
None => fr::Image::new(
|
||||
std::num::NonZeroU32::new(width).unwrap(),
|
||||
std::num::NonZeroU32::new(height).unwrap(),
|
||||
src_image.pixel_type(),
|
||||
),
|
||||
};
|
||||
|
||||
// mutable view
|
||||
let mut dst_view = dst_image
|
||||
.view_mut()
|
||||
.crop(
|
||||
0,
|
||||
0,
|
||||
std::num::NonZeroU32::new(w_new).unwrap(),
|
||||
std::num::NonZeroU32::new(h_new).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// resize
|
||||
resizer.resize(&src_image.view(), &mut dst_view).unwrap();
|
||||
let buffer = dst_image.into_vec();
|
||||
|
||||
// to ndarray
|
||||
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)
|
||||
.unwrap()
|
||||
.mapv(|x| x as f32)
|
||||
.permuted_axes([2, 0, 1]);
|
||||
let mut data = ys.slice_mut(s![idx, .., .., ..]);
|
||||
data.assign(&y_);
|
||||
}
|
||||
Ok(ys)
|
||||
}
|
||||
@ -109,3 +224,7 @@ pub fn descale_mask(mask: DynamicImage, w0: f32, h0: f32, w1: f32, h1: f32) -> D
|
||||
let mask = mask.crop(0, 0, w as u32, h as u32);
|
||||
mask.resize_exact(w1 as u32, h1 as u32, image::imageops::FilterType::Triangle)
|
||||
}
|
||||
|
||||
/// Rounds `x` up to the nearest multiple of `divisor`.
///
/// `x == 0` yields `0`; the previous `x - 1 + divisor` form underflowed there.
///
/// # Panics
/// Panics when `divisor` is zero.
pub fn make_divisible(x: usize, divisor: usize) -> usize {
    let rem = x % divisor;
    if rem == 0 {
        x
    } else {
        x + (divisor - rem)
    }
}
|
||||
|
@ -67,7 +67,7 @@ impl Default for Options {
|
||||
onnx_path: String::new(),
|
||||
device: Device::Cuda(0),
|
||||
profile: false,
|
||||
num_dry_run: 3,
|
||||
num_dry_run: 5,
|
||||
i00: None,
|
||||
i01: None,
|
||||
i02: None,
|
||||
|
55
src/core/ts.rs
Normal file
@ -0,0 +1,55 @@
|
||||
use std::time::Duration;
|
||||
|
||||
/// Accumulated timings for a fixed sequence of stages.
///
/// `ts[i]` holds the summed duration recorded for stage `i`, and `n` counts
/// every call to `add_or_push`.
#[derive(Debug, Default)]
pub struct Ts {
    n: usize,
    ts: Vec<Duration>,
}

impl Ts {
    /// Sum of the accumulated durations of all stages.
    pub fn total(&self) -> Duration {
        self.ts.iter().sum::<Duration>()
    }

    /// Number of recorded calls divided by the number of stages.
    ///
    /// Returns `0` when no stage has been recorded yet.
    pub fn n(&self) -> usize {
        if self.ts.is_empty() {
            0
        } else {
            self.n / self.ts.len()
        }
    }

    /// `total()` divided by `n()`; `Duration::ZERO` when `n()` is zero.
    pub fn avg(&self) -> Duration {
        Self::per_round(self.total(), self.n())
    }

    /// Accumulated duration of stage `i` divided by `n()`.
    ///
    /// # Panics
    /// Panics when `i` is not a recorded stage index.
    pub fn avgi(&self, i: usize) -> Duration {
        if i >= self.ts.len() {
            panic!("Index out of bound");
        }
        Self::per_round(self.ts[i], self.n())
    }

    /// Accumulated duration of every stage, in stage order.
    pub fn ts(&self) -> &Vec<Duration> {
        &self.ts
    }

    /// Adds `x` to stage `i`, or appends `x` as a new stage when `i` is past
    /// the end of the recorded stages. Always counts as one recorded call.
    pub fn add_or_push(&mut self, i: usize, x: Duration) {
        match self.ts.get_mut(i) {
            Some(elem) => *elem += x,
            // `get_mut` returned `None`, so `i >= len` already holds.
            None => self.ts.push(x),
        }
        self.n += 1;
    }

    /// Drops all recorded stages and resets the call counter.
    pub fn clear(&mut self) {
        self.n = Default::default();
        self.ts = Default::default();
    }

    /// Divides `d` by `n`, returning zero instead of panicking when `n` is zero.
    fn per_round(d: Duration, n: usize) -> Duration {
        match u32::try_from(n) {
            Ok(0) => Duration::ZERO,
            Ok(n) => d / n,
            // More rounds than fit in u32: saturate the divisor.
            Err(_) => d / u32::MAX,
        }
    }
}
|
@ -19,8 +19,8 @@ pub struct Blip {
|
||||
|
||||
impl Blip {
|
||||
pub fn new(options_visual: Options, options_textual: Options) -> Result<Self> {
|
||||
let visual = OrtEngine::new(&options_visual)?;
|
||||
let textual = OrtEngine::new(&options_textual)?;
|
||||
let mut visual = OrtEngine::new(&options_visual)?;
|
||||
let mut textual = OrtEngine::new(&options_textual)?;
|
||||
let (batch_visual, batch_textual, height, width) = (
|
||||
visual.batch().to_owned(),
|
||||
textual.batch().to_owned(),
|
||||
@ -42,18 +42,21 @@ impl Blip {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode_images(&self, xs: &[DynamicImage]) -> Result<Embedding> {
|
||||
let xs_ = ops::resize(xs, self.height.opt as u32, self.width.opt as u32)?;
|
||||
let xs_ = ops::normalize(xs_, 0.0, 255.0);
|
||||
pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result<Embedding> {
|
||||
let xs_ = ops::resize(
|
||||
xs,
|
||||
self.height.opt as u32,
|
||||
self.width.opt as u32,
|
||||
"bilinear",
|
||||
)?;
|
||||
let xs_ = ops::normalize(xs_, 0., 255.);
|
||||
let xs_ = ops::standardize(
|
||||
xs_,
|
||||
&[0.48145466, 0.4578275, 0.40821073],
|
||||
&[0.26862954, 0.2613026, 0.2757771],
|
||||
);
|
||||
let ys: Vec<Array<f32, IxDyn>> = self.visual.run(&[xs_])?;
|
||||
// let ys = ys[0].to_owned();
|
||||
Ok(Embedding::new(ys[0].to_owned()))
|
||||
// Ok(ys)
|
||||
}
|
||||
|
||||
pub fn caption(
|
||||
|
@ -19,8 +19,8 @@ pub struct Clip {
|
||||
impl Clip {
|
||||
pub fn new(options_visual: Options, options_textual: Options) -> Result<Self> {
|
||||
let context_length = 77;
|
||||
let visual = OrtEngine::new(&options_visual)?;
|
||||
let textual = OrtEngine::new(&options_textual)?;
|
||||
let mut visual = OrtEngine::new(&options_visual)?;
|
||||
let mut textual = OrtEngine::new(&options_textual)?;
|
||||
let (batch_visual, batch_textual, height, width) = (
|
||||
visual.inputs_minoptmax()[0][0].to_owned(),
|
||||
textual.inputs_minoptmax()[0][0].to_owned(),
|
||||
@ -52,9 +52,14 @@ impl Clip {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode_images(&self, xs: &[DynamicImage]) -> Result<Embedding> {
|
||||
let xs_ = ops::resize(xs, self.height.opt as u32, self.width.opt as u32)?;
|
||||
let xs_ = ops::normalize(xs_, 0.0, 255.0);
|
||||
pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result<Embedding> {
|
||||
let xs_ = ops::resize(
|
||||
xs,
|
||||
self.height.opt as u32,
|
||||
self.width.opt as u32,
|
||||
"bilinear",
|
||||
)?;
|
||||
let xs_ = ops::normalize(xs_, 0., 255.);
|
||||
let xs_ = ops::standardize(
|
||||
xs_,
|
||||
&[0.48145466, 0.4578275, 0.40821073],
|
||||
@ -64,7 +69,7 @@ impl Clip {
|
||||
Ok(Embedding::new(ys[0].to_owned()))
|
||||
}
|
||||
|
||||
pub fn encode_texts(&self, texts: &[String]) -> Result<Embedding> {
|
||||
pub fn encode_texts(&mut self, texts: &[String]) -> Result<Embedding> {
|
||||
let encodings = self
|
||||
.tokenizer
|
||||
.encode_batch(texts.to_owned(), false)
|
||||
|
@ -17,8 +17,8 @@ pub struct DB {
|
||||
}
|
||||
|
||||
impl DB {
|
||||
pub fn new(options: &Options) -> Result<Self> {
|
||||
let engine = OrtEngine::new(options)?;
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.batch().to_owned(),
|
||||
engine.height().to_owned(),
|
||||
@ -27,8 +27,8 @@ impl DB {
|
||||
let confs = DynConf::new(&options.confs, 1);
|
||||
let unclip_ratio = options.unclip_ratio;
|
||||
let binary_thresh = 0.2;
|
||||
let min_width = options.min_width.unwrap_or(0.0);
|
||||
let min_height = options.min_height.unwrap_or(0.0);
|
||||
let min_width = options.min_width.unwrap_or(0.);
|
||||
let min_height = options.min_height.unwrap_or(0.);
|
||||
engine.dry_run()?;
|
||||
|
||||
Ok(Self {
|
||||
@ -45,8 +45,14 @@ impl DB {
|
||||
}
|
||||
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = ops::letterbox(xs, self.height.opt as u32, self.width.opt as u32, 144.0)?;
|
||||
let xs_ = ops::normalize(xs_, 0.0, 255.0);
|
||||
let xs_ = ops::letterbox(
|
||||
xs,
|
||||
self.height.opt as u32,
|
||||
self.width.opt as u32,
|
||||
"bilinear",
|
||||
Some(114),
|
||||
)?;
|
||||
let xs_ = ops::normalize(xs_, 0., 255.);
|
||||
let xs_ = ops::standardize(xs_, &[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225]);
|
||||
let ys = self.engine.run(&[xs_])?;
|
||||
self.postprocess(ys, xs)
|
||||
|
@ -12,8 +12,8 @@ pub struct DepthAnything {
|
||||
}
|
||||
|
||||
impl DepthAnything {
|
||||
pub fn new(options: &Options) -> Result<Self> {
|
||||
let engine = OrtEngine::new(options)?;
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.batch().to_owned(),
|
||||
engine.height().to_owned(),
|
||||
@ -29,8 +29,13 @@ impl DepthAnything {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run(&self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = ops::resize(xs, self.height.opt as u32, self.width.opt as u32)?;
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = ops::resize(
|
||||
xs,
|
||||
self.height.opt as u32,
|
||||
self.width.opt as u32,
|
||||
"lanczos3",
|
||||
)?;
|
||||
let xs_ = ops::normalize(xs_, 0.0, 255.0);
|
||||
let xs_ = ops::standardize(xs_, &[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225]);
|
||||
let ys = self.engine.run(&[xs_])?;
|
||||
|
@ -21,14 +21,14 @@ pub struct Dinov2 {
|
||||
}
|
||||
|
||||
impl Dinov2 {
|
||||
pub fn new(options: &Options) -> Result<Self> {
|
||||
let engine = OrtEngine::new(options)?;
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.inputs_minoptmax()[0][0].to_owned(),
|
||||
engine.inputs_minoptmax()[0][2].to_owned(),
|
||||
engine.inputs_minoptmax()[0][3].to_owned(),
|
||||
);
|
||||
let which = match &options.onnx_path {
|
||||
let which = match options.onnx_path {
|
||||
s if s.contains("b14") => Model::B,
|
||||
s if s.contains("s14") => Model::S,
|
||||
_ => todo!(),
|
||||
@ -49,7 +49,12 @@ impl Dinov2 {
|
||||
}
|
||||
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Array<f32, IxDyn>> {
|
||||
let xs_ = ops::resize(xs, self.height.opt as u32, self.width.opt as u32)?;
|
||||
let xs_ = ops::resize(
|
||||
xs,
|
||||
self.height.opt as u32,
|
||||
self.width.opt as u32,
|
||||
"lanczos3",
|
||||
)?;
|
||||
let xs_ = ops::normalize(xs_, 0.0, 255.0);
|
||||
let xs_ = ops::standardize(
|
||||
xs_,
|
||||
|
@ -13,8 +13,8 @@ pub struct MODNet {
|
||||
}
|
||||
|
||||
impl MODNet {
|
||||
pub fn new(options: &Options) -> Result<Self> {
|
||||
let engine = OrtEngine::new(options)?;
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.batch().to_owned(),
|
||||
engine.height().to_owned(),
|
||||
@ -30,9 +30,14 @@ impl MODNet {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run(&self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = ops::resize(xs, self.height.opt as u32, self.width.opt as u32)?;
|
||||
let xs_ = ops::normalize(xs_, 127.5, 255.0);
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = ops::resize(
|
||||
xs,
|
||||
self.height.opt as u32,
|
||||
self.width.opt as u32,
|
||||
"lanczos3",
|
||||
)?;
|
||||
let xs_ = ops::normalize(xs_, 127.5, 255.);
|
||||
let ys = self.engine.run(&[xs_])?;
|
||||
self.postprocess(ys, xs)
|
||||
}
|
||||
|
@ -17,14 +17,14 @@ pub struct RTDETR {
|
||||
}
|
||||
|
||||
impl RTDETR {
|
||||
pub fn new(options: &Options) -> Result<Self> {
|
||||
let engine = OrtEngine::new(options)?;
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.inputs_minoptmax()[0][0].to_owned(),
|
||||
engine.inputs_minoptmax()[0][2].to_owned(),
|
||||
engine.inputs_minoptmax()[0][3].to_owned(),
|
||||
);
|
||||
let names: Option<_> = match &options.names {
|
||||
let names: Option<_> = match options.names {
|
||||
None => engine.try_fetch("names").map(|names| {
|
||||
let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap();
|
||||
let mut names_ = vec![];
|
||||
@ -56,7 +56,13 @@ impl RTDETR {
|
||||
}
|
||||
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = ops::letterbox(xs, self.height() as u32, self.width() as u32, 144.0)?;
|
||||
let xs_ = ops::letterbox(
|
||||
xs,
|
||||
self.height() as u32,
|
||||
self.width() as u32,
|
||||
"catmullRom",
|
||||
Some(114),
|
||||
)?;
|
||||
let xs_ = ops::normalize(xs_, 0.0, 255.0);
|
||||
let ys = self.engine.run(&[xs_])?;
|
||||
self.postprocess(ys, xs)
|
||||
|
@ -15,8 +15,8 @@ pub struct RTMO {
|
||||
}
|
||||
|
||||
impl RTMO {
|
||||
pub fn new(options: &Options) -> Result<Self> {
|
||||
let engine = OrtEngine::new(options)?;
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.batch().to_owned(),
|
||||
engine.height().to_owned(),
|
||||
@ -39,7 +39,13 @@ impl RTMO {
|
||||
}
|
||||
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = ops::letterbox(xs, self.height() as u32, self.width() as u32, 114.0)?;
|
||||
let xs_ = ops::letterbox(
|
||||
xs,
|
||||
self.height() as u32,
|
||||
self.width() as u32,
|
||||
"catmullRom",
|
||||
Some(114),
|
||||
)?;
|
||||
let ys = self.engine.run(&[xs_])?;
|
||||
self.postprocess(ys, xs)
|
||||
}
|
||||
|
@ -15,8 +15,8 @@ pub struct SVTR {
|
||||
}
|
||||
|
||||
impl SVTR {
|
||||
pub fn new(options: &Options) -> Result<Self> {
|
||||
let engine = OrtEngine::new(options)?;
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.batch().to_owned(),
|
||||
engine.height().to_owned(),
|
||||
@ -24,7 +24,7 @@ impl SVTR {
|
||||
);
|
||||
let confs = DynConf::new(&options.confs, 1);
|
||||
let mut vocab: Vec<_> =
|
||||
std::fs::read_to_string(options.vocab.as_ref().expect("No vocabulary found"))?
|
||||
std::fs::read_to_string(options.vocab.expect("No vocabulary found"))?
|
||||
.lines()
|
||||
.map(|line| line.to_string())
|
||||
.collect();
|
||||
@ -43,8 +43,13 @@ impl SVTR {
|
||||
}
|
||||
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ =
|
||||
ops::resize_with_fixed_height(xs, self.height.opt as u32, self.width.opt as u32, 0.0)?;
|
||||
let xs_ = ops::resize_with_fixed_height(
|
||||
xs,
|
||||
self.height.opt as u32,
|
||||
self.width.opt as u32,
|
||||
"bilinear",
|
||||
Some(0),
|
||||
)?;
|
||||
let xs_ = ops::normalize(xs_, 0.0, 255.0);
|
||||
let ys: Vec<Array<f32, IxDyn>> = self.engine.run(&[xs_])?;
|
||||
let ys = ys[0].to_owned();
|
||||
|
@ -1,6 +1,6 @@
|
||||
use anyhow::Result;
|
||||
use clap::ValueEnum;
|
||||
use image::DynamicImage;
|
||||
use image::{DynamicImage, ImageBuffer};
|
||||
use ndarray::{s, Array, Axis, IxDyn};
|
||||
use regex::Regex;
|
||||
|
||||
@ -40,16 +40,16 @@ pub struct YOLO {
|
||||
}
|
||||
|
||||
impl YOLO {
|
||||
pub fn new(options: &Options) -> Result<Self> {
|
||||
let engine = OrtEngine::new(options)?;
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.batch().to_owned(),
|
||||
engine.height().to_owned(),
|
||||
engine.width().to_owned(),
|
||||
);
|
||||
|
||||
let task = match &options.yolo_task {
|
||||
Some(task) => task.to_owned(),
|
||||
let task = match options.yolo_task {
|
||||
Some(task) => task,
|
||||
None => match engine
|
||||
.try_fetch("task")
|
||||
.unwrap_or("detect".to_string())
|
||||
@ -60,12 +60,12 @@ impl YOLO {
|
||||
"pose" => YOLOTask::Pose,
|
||||
"segment" => YOLOTask::Segment,
|
||||
"obb" => YOLOTask::Obb,
|
||||
x => todo!("{:?} is not supported for now!", x),
|
||||
x => todo!("Not supported: {x:?} "),
|
||||
},
|
||||
};
|
||||
|
||||
// try from custom class names, and then model metadata
|
||||
let mut names = options.names.to_owned().or(Self::fetch_names(&engine));
|
||||
let mut names = options.names.or(Self::fetch_names(&engine));
|
||||
let nc = match options.nc {
|
||||
Some(nc) => {
|
||||
match &names {
|
||||
@ -88,7 +88,7 @@ impl YOLO {
|
||||
},
|
||||
};
|
||||
|
||||
let names_kpt = options.names2.to_owned().or(None);
|
||||
let names_kpt = options.names2.or(None);
|
||||
|
||||
// try from model metadata
|
||||
let nk = engine
|
||||
@ -131,10 +131,18 @@ impl YOLO {
|
||||
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = match self.task {
|
||||
YOLOTask::Classify => ops::resize(xs, self.height() as u32, self.width() as u32)?,
|
||||
_ => ops::letterbox(xs, self.height() as u32, self.width() as u32, 114.0)?,
|
||||
YOLOTask::Classify => {
|
||||
ops::resize(xs, self.height() as u32, self.width() as u32, "bilinear")?
|
||||
}
|
||||
_ => ops::letterbox(
|
||||
xs,
|
||||
self.height() as u32,
|
||||
self.width() as u32,
|
||||
"catmullRom",
|
||||
Some(114),
|
||||
)?,
|
||||
};
|
||||
let xs_ = ops::normalize(xs_, 0.0, 255.0);
|
||||
let xs_ = ops::normalize(xs_, 0., 255.);
|
||||
let ys = self.engine.run(&[xs_])?;
|
||||
self.postprocess(ys, xs)
|
||||
}
|
||||
@ -333,29 +341,35 @@ impl YOLO {
|
||||
let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, n)
|
||||
|
||||
// build image from ndarray
|
||||
let mask_im = ops::build_dyn_image_from_raw(
|
||||
mask.into_raw_vec(),
|
||||
nw as u32,
|
||||
nh as u32,
|
||||
);
|
||||
let mask: ImageBuffer<image::Luma<_>, Vec<f32>> =
|
||||
match ImageBuffer::from_raw(
|
||||
nw as u32,
|
||||
nh as u32,
|
||||
mask.clone().into_raw_vec(),
|
||||
) {
|
||||
Some(buf) => buf,
|
||||
None => continue,
|
||||
};
|
||||
let mask = image::DynamicImage::from(mask);
|
||||
|
||||
// rescale masks
|
||||
// rescale
|
||||
let mask_original = ops::descale_mask(
|
||||
mask_im,
|
||||
mask,
|
||||
nw as f32,
|
||||
nh as f32,
|
||||
image_width,
|
||||
image_height,
|
||||
);
|
||||
|
||||
// crop mask with bbox
|
||||
let mut mask_original = mask_original.into_luma8();
|
||||
|
||||
// crop mask
|
||||
for y in 0..image_height as usize {
|
||||
for x in 0..image_width as usize {
|
||||
if x < bbox.xmin() as usize
|
||||
|| x > bbox.xmax() as usize
|
||||
|| y < bbox.ymin() as usize
|
||||
|| y > bbox.ymax() as usize
|
||||
// || mask_original.get_pixel(x as u32, y as u32).0 < [127]
|
||||
{
|
||||
mask_original.put_pixel(
|
||||
x as u32,
|
||||
@ -367,23 +381,25 @@ impl YOLO {
|
||||
}
|
||||
|
||||
// get masks from image
|
||||
let mut masks: Vec<Polygon> = Vec::new();
|
||||
let contours: Vec<imageproc::contours::Contour<i32>> =
|
||||
imageproc::contours::find_contours_with_threshold(
|
||||
&mask_original,
|
||||
1,
|
||||
0,
|
||||
);
|
||||
contours.iter().for_each(|contour| {
|
||||
if contour.points.len() > 2 {
|
||||
masks.push(
|
||||
Polygon::default()
|
||||
.with_id(bbox.id())
|
||||
.with_points_imageproc(&contour.points)
|
||||
.with_name(bbox.name().cloned()),
|
||||
);
|
||||
}
|
||||
});
|
||||
y_polygons.extend(masks);
|
||||
let polygon = match contours
|
||||
.iter()
|
||||
.map(|x| {
|
||||
Polygon::default()
|
||||
.with_id(bbox.id())
|
||||
.with_points_imageproc(&x.points)
|
||||
.with_name(bbox.name().cloned())
|
||||
})
|
||||
.max_by(|x, y| x.area().total_cmp(&y.area()))
|
||||
{
|
||||
None => continue,
|
||||
Some(x) => x,
|
||||
};
|
||||
y_polygons.push(polygon);
|
||||
}
|
||||
y = y.with_polygons(&y_polygons);
|
||||
}
|
||||
|
@ -15,8 +15,8 @@ pub struct YOLOPv2 {
|
||||
}
|
||||
|
||||
impl YOLOPv2 {
|
||||
pub fn new(options: &Options) -> Result<Self> {
|
||||
let engine = OrtEngine::new(options)?;
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.batch().to_owned(),
|
||||
engine.height().to_owned(),
|
||||
@ -37,8 +37,14 @@ impl YOLOPv2 {
|
||||
}
|
||||
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = ops::letterbox(xs, self.height() as u32, self.width() as u32, 114.0)?;
|
||||
let xs_ = ops::normalize(xs_, 0.0, 255.0);
|
||||
let xs_ = ops::letterbox(
|
||||
xs,
|
||||
self.height() as u32,
|
||||
self.width() as u32,
|
||||
"bilinear",
|
||||
Some(114),
|
||||
)?;
|
||||
let xs_ = ops::normalize(xs_, 0., 255.);
|
||||
let ys = self.engine.run(&[xs_])?;
|
||||
self.postprocess(ys, xs)
|
||||
}
|
||||
@ -116,19 +122,19 @@ impl YOLOPv2 {
|
||||
let mask_da = mask_da.into_luma8();
|
||||
let mut y_polygons: Vec<Polygon> = Vec::new();
|
||||
let contours: Vec<imageproc::contours::Contour<i32>> =
|
||||
imageproc::contours::find_contours_with_threshold(&mask_da, 1);
|
||||
contours.iter().for_each(|contour| {
|
||||
if contour.border_type == imageproc::contours::BorderType::Outer
|
||||
&& contour.points.len() > 2
|
||||
{
|
||||
y_polygons.push(
|
||||
Polygon::default()
|
||||
.with_id(0)
|
||||
.with_points_imageproc(&contour.points)
|
||||
.with_name(Some("Drivable area".to_string())),
|
||||
);
|
||||
}
|
||||
});
|
||||
imageproc::contours::find_contours_with_threshold(&mask_da, 0);
|
||||
if let Some(polygon) = contours
|
||||
.iter()
|
||||
.map(|x| {
|
||||
Polygon::default()
|
||||
.with_id(0)
|
||||
.with_points_imageproc(&x.points)
|
||||
.with_name(Some("Drivable area".to_string()))
|
||||
})
|
||||
.max_by(|x, y| x.area().total_cmp(&y.area()))
|
||||
{
|
||||
y_polygons.push(polygon);
|
||||
};
|
||||
|
||||
// Lane line
|
||||
let x_ll = x_ll
|
||||
@ -150,21 +156,19 @@ impl YOLOPv2 {
|
||||
);
|
||||
let mask_ll = mask_ll.into_luma8();
|
||||
let contours: Vec<imageproc::contours::Contour<i32>> =
|
||||
imageproc::contours::find_contours_with_threshold(&mask_ll, 1);
|
||||
let mut masks: Vec<Polygon> = Vec::new();
|
||||
contours.iter().for_each(|contour| {
|
||||
if contour.border_type == imageproc::contours::BorderType::Outer
|
||||
&& contour.points.len() > 2
|
||||
{
|
||||
masks.push(
|
||||
Polygon::default()
|
||||
.with_id(1)
|
||||
.with_points_imageproc(&contour.points)
|
||||
.with_name(Some("Lane line".to_string())),
|
||||
);
|
||||
}
|
||||
});
|
||||
y_polygons.extend(masks);
|
||||
imageproc::contours::find_contours_with_threshold(&mask_ll, 0);
|
||||
if let Some(polygon) = contours
|
||||
.iter()
|
||||
.map(|x| {
|
||||
Polygon::default()
|
||||
.with_id(1)
|
||||
.with_points_imageproc(&x.points)
|
||||
.with_name(Some("Lane line".to_string()))
|
||||
})
|
||||
.max_by(|x, y| x.area().total_cmp(&y.area()))
|
||||
{
|
||||
y_polygons.push(polygon);
|
||||
};
|
||||
|
||||
// save
|
||||
ys.push(
|
||||
|
@ -31,7 +31,7 @@ impl std::fmt::Debug for Bbox {
|
||||
f.debug_struct("Bbox")
|
||||
.field("xyxy", &[self.x, self.y, self.xmax(), self.ymax()])
|
||||
.field("id", &self.id)
|
||||
.field("id_born", &self.id_born)
|
||||
// .field("id_born", &self.id_born)
|
||||
.field("name", &self.name)
|
||||
.field("confidence", &self.confidence)
|
||||
.finish()
|
||||
|