From 0e8d4f832aa1b0be4b0f4ad37acc374b07f42b32 Mon Sep 17 00:00:00 2001 From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com> Date: Thu, 5 Jun 2025 16:29:29 +0800 Subject: [PATCH] Add some eps (#108) --- Cargo.toml | 21 +- examples/ben2/main.rs | 4 +- examples/blip/main.rs | 2 +- examples/classifier/main.rs | 4 +- examples/clip/main.rs | 4 +- examples/db/main.rs | 4 +- examples/depth-anything/main.rs | 2 +- examples/depth-pro/main.rs | 6 +- examples/doclayout-yolo/main.rs | 2 +- examples/fast/main.rs | 6 +- examples/fastsam/main.rs | 4 +- examples/florence2/main.rs | 4 +- examples/grounding-dino/main.rs | 4 +- examples/linknet/main.rs | 6 +- examples/moondream2/main.rs | 8 +- examples/owlv2/main.rs | 4 +- examples/rmbg/main.rs | 4 +- examples/sam/main.rs | 6 +- examples/sam2/main.rs | 4 +- examples/sapiens/main.rs | 2 +- examples/slanet/main.rs | 4 +- examples/smolvlm/main.rs | 4 +- examples/svtr/main.rs | 4 +- examples/trocr/main.rs | 10 +- examples/ultralytics-rtdetr/main.rs | 4 +- examples/yolo-sam2/main.rs | 2 +- examples/yolo/main.rs | 10 +- examples/yoloe/main.rs | 4 +- src/inference/engine.rs | 482 +++++++++++++++++++++++++- src/io/dataloader.rs | 8 +- src/io/hub.rs | 6 +- src/models/sam/impl.rs | 7 +- src/models/trocr/impl.rs | 7 +- src/utils/config.rs | 501 ++++++++++++++++++++++++++++ src/utils/device.rs | 115 +++++-- src/utils/dtype.rs | 6 +- src/utils/ort_config.rs | 173 +++++++++- src/utils/scale.rs | 8 +- src/utils/task.rs | 8 +- src/viz/color.rs | 27 +- src/viz/colormap256.rs | 28 +- 41 files changed, 1360 insertions(+), 159 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ddf9127..8dc658b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "usls" edition = "2021" -version = "0.1.0-beta.3" +version = "0.1.0-beta.4" rust-version = "1.82" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." repository = "https://github.com/jamjamjon/usls" @@ -45,6 +45,7 @@ ort = { version = "=2.0.0-rc.10", default-features = false, optional = true, fea ] } tokenizers = { version = "0.21.1" } paste = "1.0.15" +base64ct = "=1.7.3" [build-dependencies] prost-build = "0.13.5" @@ -53,11 +54,27 @@ prost-build = "0.13.5" argh = "0.1.13" tracing-subscriber = { version = "0.3.18", features = ["env-filter", "chrono"] } + [features] default = [ "ort-download-binaries" ] video = [ "dep:video-rs" ] ort-download-binaries = [ "ort", "ort/download-binaries" ] ort-load-dynamic = [ "ort", "ort/load-dynamic" ] cuda = [ "ort/cuda" ] -trt = [ "ort/tensorrt" ] +tensorrt = [ "ort/tensorrt" ] coreml = [ "ort/coreml" ] +openvino = [ "ort/openvino" ] +onednn = [ "ort/onednn" ] +directml = [ "ort/directml" ] +xnnpack = [ "ort/xnnpack" ] +cann = [ "ort/cann" ] +rknpu = [ "ort/rknpu" ] +acl = [ "ort/acl" ] +rocm = [ "ort/rocm" ] +nnapi = [ "ort/nnapi" ] +armnn = [ "ort/armnn" ] +tvm = [ "ort/tvm" ] +qnn = [ "ort/qnn" ] +migraphx = [ "ort/migraphx" ] +vitis = [ "ort/vitis" ] +azure = [ "ort/azure" ] diff --git a/examples/ben2/main.rs b/examples/ben2/main.rs index 8d44522..f1832c0 100644 --- a/examples/ben2/main.rs +++ b/examples/ben2/main.rs @@ -21,8 +21,8 @@ fn main() -> anyhow::Result<()> { // build model let config = Config::ben2_base() - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) .commit()?; let mut model = RMBG::new(config)?; diff --git a/examples/blip/main.rs b/examples/blip/main.rs index 6fd9456..e157238 100644 --- a/examples/blip/main.rs +++ b/examples/blip/main.rs @@ -21,7 +21,7 @@ fn main() -> anyhow::Result<()> { // build model let config = Config::blip_v1_base_caption() - .with_device_all(args.device.as_str().try_into()?) + .with_device_all(args.device.parse()?) .commit()?; let mut model = Blip::new(config)?; diff --git a/examples/classifier/main.rs b/examples/classifier/main.rs index d7bc9c1..822c7ca 100644 --- a/examples/classifier/main.rs +++ b/examples/classifier/main.rs @@ -46,8 +46,8 @@ fn main() -> anyhow::Result<()> { }; let config = config - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) .commit()?; let mut model = ImageClassifier::try_from(config)?; diff --git a/examples/clip/main.rs b/examples/clip/main.rs index a4ceb98..658d8b2 100644 --- a/examples/clip/main.rs +++ b/examples/clip/main.rs @@ -29,8 +29,8 @@ fn main() -> Result<()> { // clip_vit_b32() // jina_clip_v1() // jina_clip_v2() - .with_dtype_all(args.dtype.as_str().try_into()?) - .with_device_all(args.device.as_str().try_into()?) + .with_dtype_all(args.dtype.parse()?) + .with_device_all(args.device.parse()?) .commit()?; let mut model = Clip::new(config)?; diff --git a/examples/db/main.rs b/examples/db/main.rs index 6bd7b16..6863a2c 100644 --- a/examples/db/main.rs +++ b/examples/db/main.rs @@ -47,9 +47,9 @@ fn main() -> Result<()> { // build model let config = match &args.model { Some(m) => Config::db().with_model_file(m), - None => Config::ppocr_det_v5_mobile().with_model_dtype(args.dtype.as_str().try_into()?), + None => Config::ppocr_det_v5_mobile().with_model_dtype(args.dtype.parse()?), } - .with_device_all(args.device.as_str().try_into()?) + .with_device_all(args.device.parse()?) .commit()?; let mut model = DB::new(config)?; diff --git a/examples/depth-anything/main.rs b/examples/depth-anything/main.rs index 2523063..8a851b7 100644 --- a/examples/depth-anything/main.rs +++ b/examples/depth-anything/main.rs @@ -18,7 +18,7 @@ fn main() -> Result<()> { // annotate let annotator = - Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".into())); + Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".parse()?)); for (x, y) in xs.iter().zip(ys.iter()) { annotator.annotate(x, y)?.save(format!( "{}.jpg", diff --git a/examples/depth-pro/main.rs b/examples/depth-pro/main.rs index 84002e6..1482bf3 100644 --- a/examples/depth-pro/main.rs +++ b/examples/depth-pro/main.rs @@ -24,8 +24,8 @@ fn main() -> Result<()> { // model let config = Config::depth_pro() - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) .commit()?; let mut model = DepthPro::new(config)?; @@ -38,7 +38,7 @@ fn main() -> Result<()> { // annotate let annotator = - Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".into())); + Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".parse()?)); for (x, y) in xs.iter().zip(ys.iter()) { annotator.annotate(x, y)?.save(format!( "{}.jpg", diff --git a/examples/doclayout-yolo/main.rs b/examples/doclayout-yolo/main.rs index 52eb13d..8062500 100644 --- a/examples/doclayout-yolo/main.rs +++ b/examples/doclayout-yolo/main.rs @@ -19,7 +19,7 @@ fn main() -> Result<()> { // build model let config = Config::doclayout_yolo_docstructbench() - .with_model_device(args.device.as_str().try_into()?) + .with_model_device(args.device.parse()?) .commit()?; let mut model = YOLO::new(config)?; diff --git a/examples/fast/main.rs b/examples/fast/main.rs index 875b466..316333b 100644 --- a/examples/fast/main.rs +++ b/examples/fast/main.rs @@ -26,7 +26,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = match args.scale.as_str().try_into()? { + let config = match args.scale.parse()? { Scale::T => Config::fast_tiny(), Scale::S => Config::fast_small(), Scale::B => Config::fast_base(), @@ -34,8 +34,8 @@ fn main() -> Result<()> { }; let mut model = DB::new( config - .with_dtype_all(args.dtype.as_str().try_into()?) - .with_device_all(args.device.as_str().try_into()?) + .with_dtype_all(args.dtype.parse()?) + .with_device_all(args.device.parse()?) .commit()?, )?; diff --git a/examples/fastsam/main.rs b/examples/fastsam/main.rs index f5aa616..73afe55 100644 --- a/examples/fastsam/main.rs +++ b/examples/fastsam/main.rs @@ -23,8 +23,8 @@ fn main() -> Result<()> { // build model let config = Config::fastsam_s() - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) .commit()?; let mut model = YOLO::new(config)?; diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs index 3ab421f..9102958 100644 --- a/examples/florence2/main.rs +++ b/examples/florence2/main.rs @@ -26,8 +26,8 @@ fn main() -> Result<()> { // build model let config = Config::florence2_base() - .with_dtype_all(args.dtype.as_str().try_into()?) - .with_device_all(args.device.as_str().try_into()?) + .with_dtype_all(args.dtype.parse()?) + .with_device_all(args.device.parse()?) .with_batch_size_all(xs.len()) .commit()?; let mut model = Florence2::new(config)?; diff --git a/examples/grounding-dino/main.rs b/examples/grounding-dino/main.rs index 1050d95..164af72 100644 --- a/examples/grounding-dino/main.rs +++ b/examples/grounding-dino/main.rs @@ -46,8 +46,8 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); let config = Config::grounding_dino_tiny() - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) .with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::>()) .with_class_confs(&[0.25]) .with_text_confs(&[0.25]) diff --git a/examples/linknet/main.rs b/examples/linknet/main.rs index b8a4523..937408f 100644 --- a/examples/linknet/main.rs +++ b/examples/linknet/main.rs @@ -27,7 +27,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = match args.scale.as_str().try_into()? { + let config = match args.scale.parse()? { Scale::T => Config::linknet_r18(), Scale::S => Config::linknet_r34(), Scale::B => Config::linknet_r50(), @@ -35,8 +35,8 @@ fn main() -> Result<()> { }; let mut model = DB::new( config - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) .commit()?, )?; diff --git a/examples/moondream2/main.rs b/examples/moondream2/main.rs index f23e9cf..7cf9810 100644 --- a/examples/moondream2/main.rs +++ b/examples/moondream2/main.rs @@ -39,13 +39,13 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = match args.scale.as_str().try_into()? { + let config = match args.scale.parse()? { Scale::Billion(0.5) => Config::moondream2_0_5b(), Scale::Billion(2.) => Config::moondream2_2b(), _ => unimplemented!(), } - .with_dtype_all(args.dtype.as_str().try_into()?) - .with_device_all(args.device.as_str().try_into()?) + .with_dtype_all(args.dtype.parse()?) + .with_device_all(args.device.parse()?) .commit()?; let mut model = Moondream2::new(config)?; @@ -54,7 +54,7 @@ fn main() -> Result<()> { let xs = DataLoader::try_read_n(&args.source)?; // run with task - let task: Task = args.task.as_str().try_into()?; + let task: Task = args.task.parse()?; let ys = model.forward(&xs, &task)?; // annotate diff --git a/examples/owlv2/main.rs b/examples/owlv2/main.rs index 0037d50..f6dac0a 100644 --- a/examples/owlv2/main.rs +++ b/examples/owlv2/main.rs @@ -49,8 +49,8 @@ fn main() -> Result<()> { // config let config = Config::owlv2_base_ensemble() // owlv2_base() - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) .with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::>()) .commit()?; let mut model = OWLv2::new(config)?; diff --git a/examples/rmbg/main.rs b/examples/rmbg/main.rs index 0d685a3..7fcf194 100644 --- a/examples/rmbg/main.rs +++ b/examples/rmbg/main.rs @@ -31,8 +31,8 @@ fn main() -> anyhow::Result<()> { // build model let config = config - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) .commit()?; let mut model = RMBG::new(config)?; diff --git a/examples/sam/main.rs b/examples/sam/main.rs index 9a698d8..96b994e 100644 --- a/examples/sam/main.rs +++ b/examples/sam/main.rs @@ -28,9 +28,9 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // Build model - let config = match args.kind.as_str().try_into()? { + let config = match args.kind.parse()? { SamKind::Sam => Config::sam_v1_base(), - SamKind::Sam2 => match args.scale.as_str().try_into()? { + SamKind::Sam2 => match args.scale.parse()? { Scale::T => Config::sam2_tiny(), Scale::S => Config::sam2_small(), Scale::B => Config::sam2_base_plus(), @@ -40,7 +40,7 @@ fn main() -> Result<()> { SamKind::SamHq => Config::sam_hq_tiny(), SamKind::EdgeSam => Config::edge_sam_3x(), } - .with_device_all(args.device.as_str().try_into()?) + .with_device_all(args.device.parse()?) .commit()?; let mut model = SAM::new(config)?; diff --git a/examples/sam2/main.rs b/examples/sam2/main.rs index 48eef6e..6b75c1d 100644 --- a/examples/sam2/main.rs +++ b/examples/sam2/main.rs @@ -25,14 +25,14 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // Build model - let config = match args.scale.as_str().try_into()? { + let config = match args.scale.parse()? { Scale::T => Config::sam2_1_tiny(), Scale::S => Config::sam2_1_small(), Scale::B => Config::sam2_1_base_plus(), Scale::L => Config::sam2_1_large(), _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale), } - .with_device_all(args.device.as_str().try_into()?) + .with_device_all(args.device.parse()?) .commit()?; let mut model = SAM2::new(config)?; diff --git a/examples/sapiens/main.rs b/examples/sapiens/main.rs index caf2d17..4324cc5 100644 --- a/examples/sapiens/main.rs +++ b/examples/sapiens/main.rs @@ -18,7 +18,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build let config = Config::sapiens_seg_0_3b() - .with_model_device(args.device.as_str().try_into()?) + .with_model_device(args.device.parse()?) .commit()?; let mut model = Sapiens::new(config)?; diff --git a/examples/slanet/main.rs b/examples/slanet/main.rs index 71a0fc7..90a68ce 100644 --- a/examples/slanet/main.rs +++ b/examples/slanet/main.rs @@ -27,8 +27,8 @@ fn main() -> Result<()> { // build model let config = Config::slanet_lcnet_v2_mobile_ch() - .with_model_device(args.device.as_str().try_into()?) - .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.parse()?) + .with_model_dtype(args.dtype.parse()?) .commit()?; let mut model = SLANet::new(config)?; diff --git a/examples/smolvlm/main.rs b/examples/smolvlm/main.rs index 87d58b0..4d2b99b 100644 --- a/examples/smolvlm/main.rs +++ b/examples/smolvlm/main.rs @@ -29,12 +29,12 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = match args.scale.as_str().try_into()? { + let config = match args.scale.parse()? { Scale::Million(256.) => Config::smolvlm_256m(), Scale::Million(500.) => Config::smolvlm_500m(), _ => unimplemented!(), } - .with_device_all(args.device.as_str().try_into()?) + .with_device_all(args.device.parse()?) .commit()?; let mut model = SmolVLM::new(config)?; diff --git a/examples/svtr/main.rs b/examples/svtr/main.rs index d0d9d53..492b2c7 100644 --- a/examples/svtr/main.rs +++ b/examples/svtr/main.rs @@ -32,8 +32,8 @@ fn main() -> Result<()> { // ppocr_rec_v4_en() // repsvtr_ch() .with_model_ixx(0, 3, args.max_text_length.into()) - .with_model_device(args.device.as_str().try_into()?) - .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.parse()?) + .with_model_dtype(args.dtype.parse()?) .commit()?; let mut model = SVTR::new(config)?; diff --git a/examples/trocr/main.rs b/examples/trocr/main.rs index 8e5fcc8..f215989 100644 --- a/examples/trocr/main.rs +++ b/examples/trocr/main.rs @@ -38,19 +38,19 @@ fn main() -> anyhow::Result<()> { ])?; // build model - let config = match args.scale.as_str().try_into()? { - Scale::S => match args.kind.as_str().try_into()? { + let config = match args.scale.parse()? { + Scale::S => match args.kind.parse()? { TrOCRKind::Printed => Config::trocr_small_printed(), TrOCRKind::HandWritten => Config::trocr_small_handwritten(), }, - Scale::B => match args.kind.as_str().try_into()? { + Scale::B => match args.kind.parse()? { TrOCRKind::Printed => Config::trocr_base_printed(), TrOCRKind::HandWritten => Config::trocr_base_handwritten(), }, x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x), } - .with_device_all(args.device.as_str().try_into()?) - .with_dtype_all(args.dtype.as_str().try_into()?) + .with_device_all(args.device.parse()?) + .with_dtype_all(args.dtype.parse()?) .commit()?; let mut model = TrOCR::new(config)?; diff --git a/examples/ultralytics-rtdetr/main.rs b/examples/ultralytics-rtdetr/main.rs index a67ebde..1506162 100644 --- a/examples/ultralytics-rtdetr/main.rs +++ b/examples/ultralytics-rtdetr/main.rs @@ -23,8 +23,8 @@ fn main() -> Result<()> { // build model let config = Config::ultralytics_rtdetr_l() - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) .commit()?; let mut model = YOLO::new(config)?; diff --git a/examples/yolo-sam2/main.rs b/examples/yolo-sam2/main.rs index 5e9fb53..fe15d53 100644 --- a/examples/yolo-sam2/main.rs +++ b/examples/yolo-sam2/main.rs @@ -27,7 +27,7 @@ fn main() -> Result<()> { let options_yolo = Config::yolo_detect() .with_scale(Scale::N) .with_version(8.into()) - .with_model_device(args.device.as_str().try_into()?) + .with_model_device(args.device.parse()?) .commit()?; let mut yolo = YOLO::new(options_yolo)?; diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index e2f9007..1960a33 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -132,12 +132,12 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); let mut config = Config::yolo() .with_model_file(&args.model.unwrap_or_default()) - .with_task(args.task.as_str().try_into()?) + .with_task(args.task.parse()?) .with_version(args.ver.try_into()?) - .with_scale(args.scale.as_str().try_into()?) - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .with_model_trt_fp16(args.trt_fp16) + .with_scale(args.scale.parse()?) + .with_model_dtype(args.dtype.parse()?) + .with_model_device(args.device.parse()?) + .with_model_tensorrt_fp16(args.trt_fp16) .with_model_ixx( 0, 0, diff --git a/examples/yoloe/main.rs b/examples/yoloe/main.rs index 3fa3b2f..c6e4dbe 100644 --- a/examples/yoloe/main.rs +++ b/examples/yoloe/main.rs @@ -28,8 +28,8 @@ fn main() -> Result<()> { // yoloe_11s_seg_pf() // yoloe_11m_seg_pf() // yoloe_11l_seg_pf() - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + .with_model_dtype(args.dtype.as_str().parse()?) + .with_model_device(args.device.as_str().parse()?) .commit()?; let mut model = YOLO::new(config)?; diff --git a/src/inference/engine.rs b/src/inference/engine.rs index 2186660..6206f52 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -66,7 +66,6 @@ pub struct Engine { pub file: String, pub spec: String, pub device: Device, - pub trt_fp16: bool, #[args(inc)] pub iiixs: Vec, #[args(aka = "parameters")] @@ -77,7 +76,50 @@ pub struct Engine { pub onnx: Option, pub ts: Ts, pub num_dry_run: usize, + + // global pub graph_opt_level: Option, + pub num_intra_threads: Option, + pub num_inter_threads: Option, + + // cpu + pub cpu_arena_allocator: bool, + + // tensorrt + pub tensorrt_fp16: bool, + pub tensorrt_engine_cache: bool, + pub tensorrt_timing_cache: bool, + + // openvino + pub openvino_dynamic_shapes: bool, + pub openvino_opencl_throttling: bool, + pub openvino_qdq_optimizer: bool, + pub openvino_num_threads: Option, + + // onednn + pub onednn_arena_allocator: bool, + + // coreml + pub coreml_static_input_shapes: bool, + pub coreml_subgraph_running: bool, + + // cann + pub cann_graph_inference: bool, + pub cann_dump_graphs: bool, + pub cann_dump_om_model: bool, + + // nnapi + pub nnapi_cpu_only: bool, + pub nnapi_disable_cpu: bool, + pub nnapi_fp16: bool, + pub nnapi_nchw: bool, + + // armnn + pub armnn_arena_allocator: bool, + + // migraphx + pub migraphx_fp16: bool, + pub migraphx_exhaustive_tune: bool, } impl Default for Engine { @@ -85,7 +127,6 @@ impl Default for Engine { Self { file: Default::default(), device: Device::Cpu(0), - trt_fp16: false, spec: Default::default(), iiixs: Default::default(), num_dry_run: 3, @@ -94,7 +135,40 @@ impl Default for Engine { inputs_minoptmax: vec![], onnx: None, ts: Ts::default(), + // global graph_opt_level: None, + num_intra_threads: None, + num_inter_threads: None, + // cpu + cpu_arena_allocator: true, + // openvino + openvino_dynamic_shapes: true, + openvino_opencl_throttling: true, + openvino_qdq_optimizer: true, + openvino_num_threads: None, + // onednn + onednn_arena_allocator: true, + // coreml + coreml_static_input_shapes: false, + coreml_subgraph_running: true, + // tensorrt + tensorrt_fp16: true, + tensorrt_engine_cache: true, + tensorrt_timing_cache: false, + // cann + cann_graph_inference: true, + cann_dump_graphs: false, + cann_dump_om_model: false, + // nnapi + nnapi_cpu_only: false, + nnapi_disable_cpu: false, + nnapi_fp16: true, + nnapi_nchw: false, + // armnn + armnn_arena_allocator: true, + // migraphx + migraphx_fp16: true, + migraphx_exhaustive_tune: false, } } } @@ -106,9 +180,40 @@ impl Engine { spec: config.spec.clone(), iiixs: config.iiixs.clone(), device: config.device, - trt_fp16: config.trt_fp16, num_dry_run: config.num_dry_run, + // global graph_opt_level: config.graph_opt_level, + num_intra_threads: config.num_intra_threads, + num_inter_threads: config.num_inter_threads, + // cpu + cpu_arena_allocator: config.cpu_arena_allocator, + // openvino + openvino_dynamic_shapes: config.openvino_dynamic_shapes, + openvino_opencl_throttling: config.openvino_opencl_throttling, + openvino_qdq_optimizer: config.openvino_qdq_optimizer, + openvino_num_threads: config.openvino_num_threads, + // coreml + coreml_static_input_shapes: config.coreml_static_input_shapes, + coreml_subgraph_running: config.coreml_subgraph_running, + // tensorrt + tensorrt_fp16: config.tensorrt_fp16, + tensorrt_engine_cache: config.tensorrt_engine_cache, + tensorrt_timing_cache: config.tensorrt_timing_cache, + // cann + cann_graph_inference: config.cann_graph_inference, + cann_dump_graphs: config.cann_dump_graphs, + cann_dump_om_model: config.cann_dump_om_model, + // nnapi + nnapi_cpu_only: config.nnapi_cpu_only, + nnapi_disable_cpu: config.nnapi_disable_cpu, + nnapi_fp16: config.nnapi_fp16, + nnapi_nchw: config.nnapi_nchw, + // armnn + armnn_arena_allocator: config.armnn_arena_allocator, + // migraphx + migraphx_fp16: config.migraphx_fp16, + migraphx_exhaustive_tune: config.migraphx_exhaustive_tune, + ..Default::default() } .build() @@ -338,17 +443,20 @@ impl Engine { let compile_help = "Please compile ONNXRuntime with #EP"; let feature_help = "#EP EP requires the features: `#FEATURE`. \ \nConsider enabling them by passing, e.g., `--features #FEATURE`"; + let n_threads_available = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1); match self.device { Device::TensorRt(id) => { - #[cfg(not(feature = "trt"))] + #[cfg(not(feature = "tensorrt"))] { anyhow::bail!(feature_help .replace("#EP", "TensorRT") - .replace("#FEATURE", "trt")); + .replace("#FEATURE", "tensorrt")); } - #[cfg(feature = "trt")] + #[cfg(feature = "tensorrt")] { // generate shapes let mut spec_min = String::new(); @@ -379,13 +487,16 @@ impl Engine { spec_max += &s_max; } - let p = crate::Dir::Cache.crate_dir_default_with_subs(&["trt-cache"])?; let ep = ort::execution_providers::TensorRTExecutionProvider::default() .with_device_id(id as i32) - .with_fp16(self.trt_fp16) - .with_engine_cache(true) - .with_engine_cache_path(p.to_str().unwrap()) - .with_timing_cache(false) + .with_fp16(self.tensorrt_fp16) + .with_engine_cache(self.tensorrt_engine_cache) + .with_timing_cache(self.tensorrt_timing_cache) + .with_engine_cache_path( + crate::Dir::Cache + .crate_dir_default_with_subs(&["caches", "tensorrt"])? + .display(), + ) .with_profile_min_shapes(spec_min) .with_profile_opt_shapes(spec_opt) .with_profile_max_shapes(spec_max); @@ -427,7 +538,7 @@ impl Engine { } } } - Device::CoreMl(id) => { + Device::CoreMl => { #[cfg(not(feature = "coreml"))] { anyhow::bail!(feature_help @@ -439,12 +550,12 @@ impl Engine { let ep = ort::execution_providers::CoreMLExecutionProvider::default() .with_model_cache_dir( crate::Dir::Cache - .crate_dir_default_with_subs(&["coreml-cache"])? + .crate_dir_default_with_subs(&["caches", "coreml"])? .display(), ) + .with_static_input_shapes(self.coreml_static_input_shapes) + .with_subgraphs(self.coreml_subgraph_running) .with_compute_units(ort::execution_providers::coreml::CoreMLComputeUnits::All) - .with_static_input_shapes(false) - .with_subgraphs(true) .with_model_format(ort::execution_providers::coreml::CoreMLModelFormat::MLProgram) .with_specialization_strategy( ort::execution_providers::coreml::CoreMLSpecializationStrategy::FastPrediction, @@ -459,9 +570,345 @@ impl Engine { } } } + Device::OpenVino(dt) => { + #[cfg(not(feature = "openvino"))] + { + anyhow::bail!(feature_help + .replace("#EP", "OpenVINO") + .replace("#FEATURE", "openvino")); + } + + #[cfg(feature = "openvino")] + { + let ep = ort::execution_providers::OpenVINOExecutionProvider::default() + .with_device_type(dt) + .with_num_threads(self.openvino_num_threads.unwrap_or(n_threads_available)) + .with_dynamic_shapes(self.openvino_dynamic_shapes) + .with_opencl_throttling(self.openvino_opencl_throttling) + .with_qdq_optimizer(self.openvino_qdq_optimizer) + .with_cache_dir( + crate::Dir::Cache + .crate_dir_default_with_subs(&["caches", "openvino"])? + .display() + .to_string(), + ); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register OpenVINO: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "OpenVINO")), + } + } + } + Device::DirectMl(id) => { + #[cfg(not(feature = "directml"))] + { + anyhow::bail!(feature_help + .replace("#EP", "DirectML") + .replace("#FEATURE", "directml")); + } + #[cfg(feature = "directml")] + { + let ep = ort::execution_providers::DirectMLExecutionProvider::default() + .with_device_id(id as i32); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register DirectML: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "DirectML")), + } + } + } + Device::Xnnpack => { + #[cfg(not(feature = "xnnpack"))] + { + anyhow::bail!(feature_help + .replace("#EP", "XNNPack") + .replace("#FEATURE", "xnnpack")); + } + #[cfg(feature = "xnnpack")] + { + let ep = ort::execution_providers::XNNPACKExecutionProvider::default(); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register XNNPack: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "XNNPack")), + } + } + } + Device::Cann(id) => { + #[cfg(not(feature = "cann"))] + { + anyhow::bail!(feature_help + .replace("#EP", "CANN") + .replace("#FEATURE", "cann")); + } + #[cfg(feature = "cann")] + { + let ep = ort::execution_providers::CANNExecutionProvider::default() + .with_device_id(id as i32) + .with_cann_graph(self.cann_graph_inference) + .with_dump_graphs(self.cann_dump_graphs) + .with_dump_om_model(self.cann_dump_om_model); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register CANN: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "CANN")), + } + } + } + Device::RkNpu => { + #[cfg(not(feature = "rknpu"))] + { + anyhow::bail!(feature_help + .replace("#EP", "RKNPU") + .replace("#FEATURE", "rknpu")); + } + #[cfg(feature = "rknpu")] + { + let ep = ort::execution_providers::RKNPUExecutionProvider::default(); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register RKNPU: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "RKNPU")), + } + } + } + Device::OneDnn => { + #[cfg(not(feature = "onednn"))] + { + anyhow::bail!(feature_help + .replace("#EP", "oneDNN") + .replace("#FEATURE", "onednn")); + } + #[cfg(feature = "onednn")] + { + let ep = ort::execution_providers::OneDNNExecutionProvider::default() + .with_arena_allocator(self.onednn_arena_allocator); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register oneDNN: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "oneDNN")), + } + } + } + Device::Acl => { + #[cfg(not(feature = "acl"))] + { + anyhow::bail!(feature_help + .replace("#EP", "ArmACL") + .replace("#FEATURE", "acl")); + } + #[cfg(feature = "acl")] + { + let ep = ort::execution_providers::ACLExecutionProvider::default() + .with_fast_math(true); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register ArmACL: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "ArmACL")), + } + } + } + Device::Rocm(id) => { + #[cfg(not(feature = "rocm"))] + { + anyhow::bail!(feature_help + .replace("#EP", "ROCm") + .replace("#FEATURE", "rocm")); + } + #[cfg(feature = "rocm")] + { + let ep = ort::execution_providers::ROCmExecutionProvider::default() + .with_device_id(id as _); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register ROCm: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "ROCm")), + } + } + } + Device::NnApi => { + #[cfg(not(feature = "nnapi"))] + { + anyhow::bail!(feature_help + .replace("#EP", "NNAPI") + .replace("#FEATURE", "nnapi")); + } + #[cfg(feature = "nnapi")] + { + let ep = ort::execution_providers::NNAPIExecutionProvider::default() + .with_fp16(self.nnapi_fp16) + .with_nchw(self.nnapi_nchw) + .with_cpu_only(self.nnapi_cpu_only) + .with_disable_cpu(self.nnapi_disable_cpu); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register NNAPI: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "NNAPI")), + } + } + } + Device::ArmNn => { + #[cfg(not(feature = "armnn"))] + { + anyhow::bail!(feature_help + .replace("#EP", "ArmNN") + .replace("#FEATURE", "armnn")); + } + #[cfg(feature = "armnn")] + { + let ep = ort::execution_providers::ArmNNExecutionProvider::default() + .with_arena_allocator(self.armnn_arena_allocator); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register ArmNN: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "ArmNN")), + } + } + } + Device::Tvm => { + #[cfg(not(feature = "tvm"))] + { + anyhow::bail!(feature_help + .replace("#EP", "TVM") + .replace("#FEATURE", "tvm")); + } + #[cfg(feature = "tvm")] + { + let ep = ort::execution_providers::TVMExecutionProvider::default(); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register TVM: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "TVM")), + } + } + } + Device::Qnn(id) => { + #[cfg(not(feature = "qnn"))] + { + anyhow::bail!(feature_help + .replace("#EP", "QNN") + .replace("#FEATURE", "qnn")); + } + #[cfg(feature = "qnn")] + { + let ep = ort::execution_providers::QNNExecutionProvider::default() + .with_device_id(id as _); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register QNN: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "QNN")), + } + } + } + Device::MiGraphX(id) => { + #[cfg(not(feature = "migraphx"))] + { + anyhow::bail!(feature_help + .replace("#EP", "MIGraphX") + .replace("#FEATURE", "migraphx")); + } + #[cfg(feature = "migraphx")] + { + let ep = ort::execution_providers::MIGraphXExecutionProvider::default() + .with_device_id(id as _) + .with_fp16(self.migraphx_fp16) + .with_exhaustive_tune(self.migraphx_exhaustive_tune); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register MIGraphX: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "MIGraphX")), + } + } + } + Device::Vitis => { + #[cfg(not(feature = "vitis"))] + { + anyhow::bail!(feature_help + .replace("#EP", "VitisAI") + .replace("#FEATURE", "vitis")); + } + #[cfg(feature = "vitis")] + { + let ep = ort::execution_providers::VitisAIExecutionProvider::default() + .with_cache_dir( + crate::Dir::Cache + .crate_dir_default_with_subs(&["caches", "vitis"])? + .display() + .to_string(), + ); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register VitisAI: {}", err) + })?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "VitisAI")), + } + } + } + Device::Azure => { + #[cfg(not(feature = "azure"))] + { + anyhow::bail!(feature_help + .replace("#EP", "Azure") + .replace("#FEATURE", "azure")); + } + #[cfg(feature = "azure")] + { + let ep = ort::execution_providers::AzureExecutionProvider::default(); + match ep.is_available() { + Ok(true) => { + ep.register(&mut builder).map_err(|err| { + anyhow::anyhow!("Failed to register Azure: {}", err) + })?; + builder = builder.with_extensions()?; + } + _ => anyhow::bail!(compile_help.replace("#EP", "Azure")), + } + } + } _ => { let ep = ort::execution_providers::CPUExecutionProvider::default() - .with_arena_allocator(true); + .with_arena_allocator(self.cpu_arena_allocator); match ep.is_available() { Ok(true) => { ep.register(&mut builder) @@ -481,7 +928,8 @@ impl Engine { }; let session = builder .with_optimization_level(graph_opt_level)? - .with_intra_threads(std::thread::available_parallelism()?.get())? + .with_intra_threads(self.num_intra_threads.unwrap_or(n_threads_available))? + .with_inter_threads(self.num_inter_threads.unwrap_or(2))? .commit_from_file(self.file())?; Ok(session) diff --git a/src/io/dataloader.rs b/src/io/dataloader.rs index 19aed4e..4292dcf 100644 --- a/src/io/dataloader.rs +++ b/src/io/dataloader.rs @@ -5,6 +5,7 @@ use log::{info, warn}; use rayon::prelude::*; use std::collections::VecDeque; use std::path::{Path, PathBuf}; +use std::str::FromStr; use std::sync::mpsc; #[cfg(feature = "video")] use video_rs::{Decoder, Url}; @@ -80,9 +81,10 @@ impl std::fmt::Debug for DataLoader { } } -impl TryFrom<&str> for DataLoader { - type Error = anyhow::Error; - fn try_from(source: &str) -> Result { +impl FromStr for DataLoader { + type Err = anyhow::Error; + + fn from_str(source: &str) -> Result { Self::new(source) } } diff --git a/src/io/hub.rs b/src/io/hub.rs index cbd6bc2..9864fbc 100644 --- a/src/io/hub.rs +++ b/src/io/hub.rs @@ -460,7 +460,7 @@ impl Hub { fn cache_file(owner: &str, repo: &str) -> String { let safe_owner = owner.replace(|c: char| !c.is_ascii_alphanumeric(), "_"); let safe_repo = repo.replace(|c: char| !c.is_ascii_alphanumeric(), "_"); - format!(".cache-releases-{}-{}.json", safe_owner, safe_repo) + format!("releases-{}-{}.json", safe_owner, safe_repo) } fn get_releases( @@ -470,7 +470,9 @@ impl Hub { to: &Dir, ttl: &Duration, ) -> Result> { - let cache = to.crate_dir_default()?.join(Self::cache_file(owner, repo)); + let cache = to + .crate_dir_default_with_subs(&["caches"])? + .join(Self::cache_file(owner, repo)); let is_file_expired = Self::is_file_expired(&cache, ttl)?; let body = if is_file_expired { let gh_api_release = format!( diff --git a/src/models/sam/impl.rs b/src/models/sam/impl.rs index 7a76e77..2abea4f 100644 --- a/src/models/sam/impl.rs +++ b/src/models/sam/impl.rs @@ -2,6 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::{s, Axis}; use rand::{prelude::*, rng}; +use std::str::FromStr; use crate::{ elapsed, Config, DynConf, Engine, Image, Mask, Ops, Polygon, Processor, SamPrompt, Ts, Xs, X, Y, @@ -16,10 +17,10 @@ pub enum SamKind { EdgeSam, } -impl TryFrom<&str> for SamKind { - type Error = anyhow::Error; +impl FromStr for SamKind { + type Err = anyhow::Error; - fn try_from(s: &str) -> Result { + fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { "sam" => Ok(Self::Sam), "sam2" => Ok(Self::Sam2), diff --git a/src/models/trocr/impl.rs b/src/models/trocr/impl.rs index 39701fd..f1a3679 100644 --- a/src/models/trocr/impl.rs +++ b/src/models/trocr/impl.rs @@ -2,6 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::{s, Axis}; use rayon::prelude::*; +use std::str::FromStr; use crate::{elapsed, Config, Engine, Image, LogitsSampler, Processor, Scale, Ts, Xs, X, Y}; @@ -11,10 +12,10 @@ pub enum TrOCRKind { HandWritten, } -impl TryFrom<&str> for TrOCRKind { - type Error = anyhow::Error; +impl FromStr for TrOCRKind { + type Err = anyhow::Error; - fn try_from(s: &str) -> Result { + fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { "printed" => Ok(Self::Printed), "handwritten" | "hand-written" => Ok(Self::HandWritten), diff --git a/src/utils/config.rs b/src/utils/config.rs index 3f9895f..b6799db 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -261,6 +261,507 @@ impl Config { self } + + pub fn with_graph_opt_level_all(mut self, level: u8) -> Self { + self.visual = self.visual.with_graph_opt_level(level); + self.textual = self.textual.with_graph_opt_level(level); + self.model = self.model.with_graph_opt_level(level); + self.encoder = self.encoder.with_graph_opt_level(level); + self.decoder = self.decoder.with_graph_opt_level(level); + self.visual_encoder = self.visual_encoder.with_graph_opt_level(level); + self.textual_encoder = self.textual_encoder.with_graph_opt_level(level); + self.visual_decoder = self.visual_decoder.with_graph_opt_level(level); + self.textual_decoder = self.textual_decoder.with_graph_opt_level(level); + self.textual_decoder_merged = self.textual_decoder_merged.with_graph_opt_level(level); + self.size_encoder = self.size_encoder.with_graph_opt_level(level); + self.size_decoder = self.size_decoder.with_graph_opt_level(level); + self.coord_encoder = self.coord_encoder.with_graph_opt_level(level); + self.coord_decoder = self.coord_decoder.with_graph_opt_level(level); + self.visual_projection = self.visual_projection.with_graph_opt_level(level); + self.textual_projection = self.textual_projection.with_graph_opt_level(level); + self + } + + pub fn with_num_intra_threads_all(mut self, num_threads: usize) -> Self { + self.visual = self.visual.with_num_intra_threads(num_threads); + self.textual = self.textual.with_num_intra_threads(num_threads); + self.model = self.model.with_num_intra_threads(num_threads); + self.encoder = self.encoder.with_num_intra_threads(num_threads); + self.decoder = self.decoder.with_num_intra_threads(num_threads); + self.visual_encoder = self.visual_encoder.with_num_intra_threads(num_threads); + self.textual_encoder = self.textual_encoder.with_num_intra_threads(num_threads); + self.visual_decoder = self.visual_decoder.with_num_intra_threads(num_threads); + self.textual_decoder = self.textual_decoder.with_num_intra_threads(num_threads); + self.textual_decoder_merged = self + .textual_decoder_merged + .with_num_intra_threads(num_threads); + self.size_encoder = self.size_encoder.with_num_intra_threads(num_threads); + self.size_decoder = self.size_decoder.with_num_intra_threads(num_threads); + self.coord_encoder = self.coord_encoder.with_num_intra_threads(num_threads); + self.coord_decoder = self.coord_decoder.with_num_intra_threads(num_threads); + self.visual_projection = self.visual_projection.with_num_intra_threads(num_threads); + self.textual_projection = self.textual_projection.with_num_intra_threads(num_threads); + self + } + + pub fn with_num_inter_threads_all(mut self, num_threads: usize) -> Self { + self.visual = self.visual.with_num_inter_threads(num_threads); + self.textual = self.textual.with_num_inter_threads(num_threads); + self.model = self.model.with_num_inter_threads(num_threads); + self.encoder = self.encoder.with_num_inter_threads(num_threads); + self.decoder = self.decoder.with_num_inter_threads(num_threads); + self.visual_encoder = self.visual_encoder.with_num_inter_threads(num_threads); + self.textual_encoder = self.textual_encoder.with_num_inter_threads(num_threads); + self.visual_decoder = self.visual_decoder.with_num_inter_threads(num_threads); + self.textual_decoder = self.textual_decoder.with_num_inter_threads(num_threads); + self.textual_decoder_merged = self + .textual_decoder_merged + .with_num_inter_threads(num_threads); + self.size_encoder = self.size_encoder.with_num_inter_threads(num_threads); + self.size_decoder = self.size_decoder.with_num_inter_threads(num_threads); + self.coord_encoder = self.coord_encoder.with_num_inter_threads(num_threads); + self.coord_decoder = self.coord_decoder.with_num_inter_threads(num_threads); + self.visual_projection = self.visual_projection.with_num_inter_threads(num_threads); + self.textual_projection = self.textual_projection.with_num_inter_threads(num_threads); + self + } + + pub fn with_cpu_arena_allocator_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_cpu_arena_allocator(x); + self.textual = self.textual.with_cpu_arena_allocator(x); + self.model = self.model.with_cpu_arena_allocator(x); + self.encoder = self.encoder.with_cpu_arena_allocator(x); + self.decoder = self.decoder.with_cpu_arena_allocator(x); + self.visual_encoder = self.visual_encoder.with_cpu_arena_allocator(x); + self.textual_encoder = self.textual_encoder.with_cpu_arena_allocator(x); + self.visual_decoder = self.visual_decoder.with_cpu_arena_allocator(x); + self.textual_decoder = self.textual_decoder.with_cpu_arena_allocator(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_cpu_arena_allocator(x); + self.size_encoder = self.size_encoder.with_cpu_arena_allocator(x); + self.size_decoder = self.size_decoder.with_cpu_arena_allocator(x); + self.coord_encoder = self.coord_encoder.with_cpu_arena_allocator(x); + self.coord_decoder = self.coord_decoder.with_cpu_arena_allocator(x); + self.visual_projection = self.visual_projection.with_cpu_arena_allocator(x); + self.textual_projection = self.textual_projection.with_cpu_arena_allocator(x); + self + } + + pub fn with_openvino_dynamic_shapes_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_openvino_dynamic_shapes(x); + self.textual = self.textual.with_openvino_dynamic_shapes(x); + self.model = self.model.with_openvino_dynamic_shapes(x); + self.encoder = self.encoder.with_openvino_dynamic_shapes(x); + self.decoder = self.decoder.with_openvino_dynamic_shapes(x); + self.visual_encoder = self.visual_encoder.with_openvino_dynamic_shapes(x); + self.textual_encoder = self.textual_encoder.with_openvino_dynamic_shapes(x); + self.visual_decoder = self.visual_decoder.with_openvino_dynamic_shapes(x); + self.textual_decoder = self.textual_decoder.with_openvino_dynamic_shapes(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_openvino_dynamic_shapes(x); + self.size_encoder = self.size_encoder.with_openvino_dynamic_shapes(x); + self.size_decoder = self.size_decoder.with_openvino_dynamic_shapes(x); + self.coord_encoder = self.coord_encoder.with_openvino_dynamic_shapes(x); + self.coord_decoder = self.coord_decoder.with_openvino_dynamic_shapes(x); + self.visual_projection = self.visual_projection.with_openvino_dynamic_shapes(x); + self.textual_projection = self.textual_projection.with_openvino_dynamic_shapes(x); + self + } + + pub fn with_openvino_opencl_throttling_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_openvino_opencl_throttling(x); + self.textual = self.textual.with_openvino_opencl_throttling(x); + self.model = self.model.with_openvino_opencl_throttling(x); + self.encoder = self.encoder.with_openvino_opencl_throttling(x); + self.decoder = self.decoder.with_openvino_opencl_throttling(x); + self.visual_encoder = self.visual_encoder.with_openvino_opencl_throttling(x); + self.textual_encoder = self.textual_encoder.with_openvino_opencl_throttling(x); + self.visual_decoder = self.visual_decoder.with_openvino_opencl_throttling(x); + self.textual_decoder = self.textual_decoder.with_openvino_opencl_throttling(x); + self.textual_decoder_merged = self + .textual_decoder_merged + .with_openvino_opencl_throttling(x); + self.size_encoder = self.size_encoder.with_openvino_opencl_throttling(x); + self.size_decoder = self.size_decoder.with_openvino_opencl_throttling(x); + self.coord_encoder = self.coord_encoder.with_openvino_opencl_throttling(x); + self.coord_decoder = self.coord_decoder.with_openvino_opencl_throttling(x); + self.visual_projection = self.visual_projection.with_openvino_opencl_throttling(x); + self.textual_projection = self.textual_projection.with_openvino_opencl_throttling(x); + self + } + + pub fn with_openvino_qdq_optimizer_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_openvino_qdq_optimizer(x); + self.textual = self.textual.with_openvino_qdq_optimizer(x); + self.model = self.model.with_openvino_qdq_optimizer(x); + self.encoder = self.encoder.with_openvino_qdq_optimizer(x); + self.decoder = self.decoder.with_openvino_qdq_optimizer(x); + self.visual_encoder = self.visual_encoder.with_openvino_qdq_optimizer(x); + self.textual_encoder = self.textual_encoder.with_openvino_qdq_optimizer(x); + self.visual_decoder = self.visual_decoder.with_openvino_qdq_optimizer(x); + self.textual_decoder = self.textual_decoder.with_openvino_qdq_optimizer(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_openvino_qdq_optimizer(x); + self.size_encoder = self.size_encoder.with_openvino_qdq_optimizer(x); + self.size_decoder = self.size_decoder.with_openvino_qdq_optimizer(x); + self.coord_encoder = self.coord_encoder.with_openvino_qdq_optimizer(x); + self.coord_decoder = self.coord_decoder.with_openvino_qdq_optimizer(x); + self.visual_projection = self.visual_projection.with_openvino_qdq_optimizer(x); + self.textual_projection = self.textual_projection.with_openvino_qdq_optimizer(x); + self + } + + pub fn with_openvino_num_threads_all(mut self, num_threads: usize) -> Self { + self.visual = self.visual.with_openvino_num_threads(num_threads); + self.textual = self.textual.with_openvino_num_threads(num_threads); + self.model = self.model.with_openvino_num_threads(num_threads); + self.encoder = self.encoder.with_openvino_num_threads(num_threads); + self.decoder = self.decoder.with_openvino_num_threads(num_threads); + self.visual_encoder = self.visual_encoder.with_openvino_num_threads(num_threads); + self.textual_encoder = self.textual_encoder.with_openvino_num_threads(num_threads); + self.visual_decoder = self.visual_decoder.with_openvino_num_threads(num_threads); + self.textual_decoder = self.textual_decoder.with_openvino_num_threads(num_threads); + self.textual_decoder_merged = self + .textual_decoder_merged + .with_openvino_num_threads(num_threads); + self.size_encoder = self.size_encoder.with_openvino_num_threads(num_threads); + self.size_decoder = self.size_decoder.with_openvino_num_threads(num_threads); + self.coord_encoder = self.coord_encoder.with_openvino_num_threads(num_threads); + self.coord_decoder = self.coord_decoder.with_openvino_num_threads(num_threads); + self.visual_projection = self + .visual_projection + .with_openvino_num_threads(num_threads); + self.textual_projection = self + .textual_projection + .with_openvino_num_threads(num_threads); + self + } + + // onednn + pub fn with_onednn_arena_allocator_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_onednn_arena_allocator(x); + self.textual = self.textual.with_onednn_arena_allocator(x); + self.model = self.model.with_onednn_arena_allocator(x); + self.encoder = self.encoder.with_onednn_arena_allocator(x); + self.decoder = self.decoder.with_onednn_arena_allocator(x); + self.visual_encoder = self.visual_encoder.with_onednn_arena_allocator(x); + self.textual_encoder = self.textual_encoder.with_onednn_arena_allocator(x); + self.visual_decoder = self.visual_decoder.with_onednn_arena_allocator(x); + self.textual_decoder = self.textual_decoder.with_onednn_arena_allocator(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_onednn_arena_allocator(x); + self.size_encoder = self.size_encoder.with_onednn_arena_allocator(x); + self.size_decoder = self.size_decoder.with_onednn_arena_allocator(x); + self.coord_encoder = self.coord_encoder.with_onednn_arena_allocator(x); + self.coord_decoder = self.coord_decoder.with_onednn_arena_allocator(x); + self.visual_projection = self.visual_projection.with_onednn_arena_allocator(x); + self.textual_projection = self.textual_projection.with_onednn_arena_allocator(x); + self + } + + // tensorrt + pub fn with_tensorrt_fp16_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_tensorrt_fp16(x); + self.textual = self.textual.with_tensorrt_fp16(x); + self.model = self.model.with_tensorrt_fp16(x); + self.encoder = self.encoder.with_tensorrt_fp16(x); + self.decoder = self.decoder.with_tensorrt_fp16(x); + self.visual_encoder = self.visual_encoder.with_tensorrt_fp16(x); + self.textual_encoder = self.textual_encoder.with_tensorrt_fp16(x); + self.visual_decoder = self.visual_decoder.with_tensorrt_fp16(x); + self.textual_decoder = self.textual_decoder.with_tensorrt_fp16(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_tensorrt_fp16(x); + self.size_encoder = self.size_encoder.with_tensorrt_fp16(x); + self.size_decoder = self.size_decoder.with_tensorrt_fp16(x); + self.coord_encoder = self.coord_encoder.with_tensorrt_fp16(x); + self.coord_decoder = self.coord_decoder.with_tensorrt_fp16(x); + self.visual_projection = self.visual_projection.with_tensorrt_fp16(x); + self.textual_projection = self.textual_projection.with_tensorrt_fp16(x); + self + } + + pub fn with_tensorrt_engine_cache_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_tensorrt_engine_cache(x); + self.textual = self.textual.with_tensorrt_engine_cache(x); + self.model = self.model.with_tensorrt_engine_cache(x); + self.encoder = self.encoder.with_tensorrt_engine_cache(x); + self.decoder = self.decoder.with_tensorrt_engine_cache(x); + self.visual_encoder = self.visual_encoder.with_tensorrt_engine_cache(x); + self.textual_encoder = self.textual_encoder.with_tensorrt_engine_cache(x); + self.visual_decoder = self.visual_decoder.with_tensorrt_engine_cache(x); + self.textual_decoder = self.textual_decoder.with_tensorrt_engine_cache(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_tensorrt_engine_cache(x); + self.size_encoder = self.size_encoder.with_tensorrt_engine_cache(x); + self.size_decoder = self.size_decoder.with_tensorrt_engine_cache(x); + self.coord_encoder = self.coord_encoder.with_tensorrt_engine_cache(x); + self.coord_decoder = self.coord_decoder.with_tensorrt_engine_cache(x); + self.visual_projection = self.visual_projection.with_tensorrt_engine_cache(x); + self.textual_projection = self.textual_projection.with_tensorrt_engine_cache(x); + self + } + + pub fn with_tensorrt_timing_cache_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_tensorrt_timing_cache(x); + self.textual = self.textual.with_tensorrt_timing_cache(x); + self.model = self.model.with_tensorrt_timing_cache(x); + self.encoder = self.encoder.with_tensorrt_timing_cache(x); + self.decoder = self.decoder.with_tensorrt_timing_cache(x); + self.visual_encoder = self.visual_encoder.with_tensorrt_timing_cache(x); + self.textual_encoder = self.textual_encoder.with_tensorrt_timing_cache(x); + self.visual_decoder = self.visual_decoder.with_tensorrt_timing_cache(x); + self.textual_decoder = self.textual_decoder.with_tensorrt_timing_cache(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_tensorrt_timing_cache(x); + self.size_encoder = self.size_encoder.with_tensorrt_timing_cache(x); + self.size_decoder = self.size_decoder.with_tensorrt_timing_cache(x); + self.coord_encoder = self.coord_encoder.with_tensorrt_timing_cache(x); + self.coord_decoder = self.coord_decoder.with_tensorrt_timing_cache(x); + self.visual_projection = self.visual_projection.with_tensorrt_timing_cache(x); + self.textual_projection = self.textual_projection.with_tensorrt_timing_cache(x); + self + } + + // coreml + pub fn with_coreml_static_input_shapes_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_coreml_static_input_shapes(x); + self.textual = self.textual.with_coreml_static_input_shapes(x); + self.model = self.model.with_coreml_static_input_shapes(x); + self.encoder = self.encoder.with_coreml_static_input_shapes(x); + self.decoder = self.decoder.with_coreml_static_input_shapes(x); + self.visual_encoder = self.visual_encoder.with_coreml_static_input_shapes(x); + self.textual_encoder = self.textual_encoder.with_coreml_static_input_shapes(x); + self.visual_decoder = self.visual_decoder.with_coreml_static_input_shapes(x); + self.textual_decoder = self.textual_decoder.with_coreml_static_input_shapes(x); + self.textual_decoder_merged = self + .textual_decoder_merged + .with_coreml_static_input_shapes(x); + self.size_encoder = self.size_encoder.with_coreml_static_input_shapes(x); + self.size_decoder = self.size_decoder.with_coreml_static_input_shapes(x); + self.coord_encoder = self.coord_encoder.with_coreml_static_input_shapes(x); + self.coord_decoder = self.coord_decoder.with_coreml_static_input_shapes(x); + self.visual_projection = self.visual_projection.with_coreml_static_input_shapes(x); + self.textual_projection = self.textual_projection.with_coreml_static_input_shapes(x); + self + } + + pub fn with_coreml_subgraph_running_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_coreml_subgraph_running(x); + self.textual = self.textual.with_coreml_subgraph_running(x); + self.model = self.model.with_coreml_subgraph_running(x); + self.encoder = self.encoder.with_coreml_subgraph_running(x); + self.decoder = self.decoder.with_coreml_subgraph_running(x); + self.visual_encoder = self.visual_encoder.with_coreml_subgraph_running(x); + self.textual_encoder = self.textual_encoder.with_coreml_subgraph_running(x); + self.visual_decoder = self.visual_decoder.with_coreml_subgraph_running(x); + self.textual_decoder = self.textual_decoder.with_coreml_subgraph_running(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_coreml_subgraph_running(x); + self.size_encoder = self.size_encoder.with_coreml_subgraph_running(x); + self.size_decoder = self.size_decoder.with_coreml_subgraph_running(x); + self.coord_encoder = self.coord_encoder.with_coreml_subgraph_running(x); + self.coord_decoder = self.coord_decoder.with_coreml_subgraph_running(x); + self.visual_projection = self.visual_projection.with_coreml_subgraph_running(x); + self.textual_projection = self.textual_projection.with_coreml_subgraph_running(x); + self + } + + // cann + pub fn with_cann_graph_inference_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_cann_graph_inference(x); + self.textual = self.textual.with_cann_graph_inference(x); + self.model = self.model.with_cann_graph_inference(x); + self.encoder = self.encoder.with_cann_graph_inference(x); + self.decoder = self.decoder.with_cann_graph_inference(x); + self.visual_encoder = self.visual_encoder.with_cann_graph_inference(x); + self.textual_encoder = self.textual_encoder.with_cann_graph_inference(x); + self.visual_decoder = self.visual_decoder.with_cann_graph_inference(x); + self.textual_decoder = self.textual_decoder.with_cann_graph_inference(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_cann_graph_inference(x); + self.size_encoder = self.size_encoder.with_cann_graph_inference(x); + self.size_decoder = self.size_decoder.with_cann_graph_inference(x); + self.coord_encoder = self.coord_encoder.with_cann_graph_inference(x); + self.coord_decoder = self.coord_decoder.with_cann_graph_inference(x); + self.visual_projection = self.visual_projection.with_cann_graph_inference(x); + self.textual_projection = self.textual_projection.with_cann_graph_inference(x); + self + } + + pub fn with_cann_dump_graphs_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_cann_dump_graphs(x); + self.textual = self.textual.with_cann_dump_graphs(x); + self.model = self.model.with_cann_dump_graphs(x); + self.encoder = self.encoder.with_cann_dump_graphs(x); + self.decoder = self.decoder.with_cann_dump_graphs(x); + self.visual_encoder = self.visual_encoder.with_cann_dump_graphs(x); + self.textual_encoder = self.textual_encoder.with_cann_dump_graphs(x); + self.visual_decoder = self.visual_decoder.with_cann_dump_graphs(x); + self.textual_decoder = self.textual_decoder.with_cann_dump_graphs(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_cann_dump_graphs(x); + self.size_encoder = self.size_encoder.with_cann_dump_graphs(x); + self.size_decoder = self.size_decoder.with_cann_dump_graphs(x); + self.coord_encoder = self.coord_encoder.with_cann_dump_graphs(x); + self.coord_decoder = self.coord_decoder.with_cann_dump_graphs(x); + self.visual_projection = self.visual_projection.with_cann_dump_graphs(x); + self.textual_projection = self.textual_projection.with_cann_dump_graphs(x); + self + } + + pub fn with_cann_dump_om_model_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_cann_dump_om_model(x); + self.textual = self.textual.with_cann_dump_om_model(x); + self.model = self.model.with_cann_dump_om_model(x); + self.encoder = self.encoder.with_cann_dump_om_model(x); + self.decoder = self.decoder.with_cann_dump_om_model(x); + self.visual_encoder = self.visual_encoder.with_cann_dump_om_model(x); + self.textual_encoder = self.textual_encoder.with_cann_dump_om_model(x); + self.visual_decoder = self.visual_decoder.with_cann_dump_om_model(x); + self.textual_decoder = self.textual_decoder.with_cann_dump_om_model(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_cann_dump_om_model(x); + self.size_encoder = self.size_encoder.with_cann_dump_om_model(x); + self.size_decoder = self.size_decoder.with_cann_dump_om_model(x); + self.coord_encoder = self.coord_encoder.with_cann_dump_om_model(x); + self.coord_decoder = self.coord_decoder.with_cann_dump_om_model(x); + self.visual_projection = self.visual_projection.with_cann_dump_om_model(x); + self.textual_projection = self.textual_projection.with_cann_dump_om_model(x); + self + } + + // nnapi + pub fn with_nnapi_cpu_only_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_nnapi_cpu_only(x); + self.textual = self.textual.with_nnapi_cpu_only(x); + self.model = self.model.with_nnapi_cpu_only(x); + self.encoder = self.encoder.with_nnapi_cpu_only(x); + self.decoder = self.decoder.with_nnapi_cpu_only(x); + self.visual_encoder = self.visual_encoder.with_nnapi_cpu_only(x); + self.textual_encoder = self.textual_encoder.with_nnapi_cpu_only(x); + self.visual_decoder = self.visual_decoder.with_nnapi_cpu_only(x); + self.textual_decoder = self.textual_decoder.with_nnapi_cpu_only(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_cpu_only(x); + self.size_encoder = self.size_encoder.with_nnapi_cpu_only(x); + self.size_decoder = self.size_decoder.with_nnapi_cpu_only(x); + self.coord_encoder = self.coord_encoder.with_nnapi_cpu_only(x); + self.coord_decoder = self.coord_decoder.with_nnapi_cpu_only(x); + self.visual_projection = self.visual_projection.with_nnapi_cpu_only(x); + self.textual_projection = self.textual_projection.with_nnapi_cpu_only(x); + self + } + + pub fn with_nnapi_disable_cpu_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_nnapi_disable_cpu(x); + self.textual = self.textual.with_nnapi_disable_cpu(x); + self.model = self.model.with_nnapi_disable_cpu(x); + self.encoder = self.encoder.with_nnapi_disable_cpu(x); + self.decoder = self.decoder.with_nnapi_disable_cpu(x); + self.visual_encoder = self.visual_encoder.with_nnapi_disable_cpu(x); + self.textual_encoder = self.textual_encoder.with_nnapi_disable_cpu(x); + self.visual_decoder = self.visual_decoder.with_nnapi_disable_cpu(x); + self.textual_decoder = self.textual_decoder.with_nnapi_disable_cpu(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_disable_cpu(x); + self.size_encoder = self.size_encoder.with_nnapi_disable_cpu(x); + self.size_decoder = self.size_decoder.with_nnapi_disable_cpu(x); + self.coord_encoder = self.coord_encoder.with_nnapi_disable_cpu(x); + self.coord_decoder = self.coord_decoder.with_nnapi_disable_cpu(x); + self.visual_projection = self.visual_projection.with_nnapi_disable_cpu(x); + self.textual_projection = self.textual_projection.with_nnapi_disable_cpu(x); + self + } + + pub fn with_nnapi_fp16_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_nnapi_fp16(x); + self.textual = self.textual.with_nnapi_fp16(x); + self.model = self.model.with_nnapi_fp16(x); + self.encoder = self.encoder.with_nnapi_fp16(x); + self.decoder = self.decoder.with_nnapi_fp16(x); + self.visual_encoder = self.visual_encoder.with_nnapi_fp16(x); + self.textual_encoder = self.textual_encoder.with_nnapi_fp16(x); + self.visual_decoder = self.visual_decoder.with_nnapi_fp16(x); + self.textual_decoder = self.textual_decoder.with_nnapi_fp16(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_fp16(x); + self.size_encoder = self.size_encoder.with_nnapi_fp16(x); + self.size_decoder = self.size_decoder.with_nnapi_fp16(x); + self.coord_encoder = self.coord_encoder.with_nnapi_fp16(x); + self.coord_decoder = self.coord_decoder.with_nnapi_fp16(x); + self.visual_projection = self.visual_projection.with_nnapi_fp16(x); + self.textual_projection = self.textual_projection.with_nnapi_fp16(x); + self + } + + pub fn with_nnapi_nchw_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_nnapi_nchw(x); + self.textual = self.textual.with_nnapi_nchw(x); + self.model = self.model.with_nnapi_nchw(x); + self.encoder = self.encoder.with_nnapi_nchw(x); + self.decoder = self.decoder.with_nnapi_nchw(x); + self.visual_encoder = self.visual_encoder.with_nnapi_nchw(x); + self.textual_encoder = self.textual_encoder.with_nnapi_nchw(x); + self.visual_decoder = self.visual_decoder.with_nnapi_nchw(x); + self.textual_decoder = self.textual_decoder.with_nnapi_nchw(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_nchw(x); + self.size_encoder = self.size_encoder.with_nnapi_nchw(x); + self.size_decoder = self.size_decoder.with_nnapi_nchw(x); + self.coord_encoder = self.coord_encoder.with_nnapi_nchw(x); + self.coord_decoder = self.coord_decoder.with_nnapi_nchw(x); + self.visual_projection = self.visual_projection.with_nnapi_nchw(x); + self.textual_projection = self.textual_projection.with_nnapi_nchw(x); + self + } + + // armnn + pub fn with_armnn_arena_allocator_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_armnn_arena_allocator(x); + self.textual = self.textual.with_armnn_arena_allocator(x); + self.model = self.model.with_armnn_arena_allocator(x); + self.encoder = self.encoder.with_armnn_arena_allocator(x); + self.decoder = self.decoder.with_armnn_arena_allocator(x); + self.visual_encoder = self.visual_encoder.with_armnn_arena_allocator(x); + self.textual_encoder = self.textual_encoder.with_armnn_arena_allocator(x); + self.visual_decoder = self.visual_decoder.with_armnn_arena_allocator(x); + self.textual_decoder = self.textual_decoder.with_armnn_arena_allocator(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_armnn_arena_allocator(x); + self.size_encoder = self.size_encoder.with_armnn_arena_allocator(x); + self.size_decoder = self.size_decoder.with_armnn_arena_allocator(x); + self.coord_encoder = self.coord_encoder.with_armnn_arena_allocator(x); + self.coord_decoder = self.coord_decoder.with_armnn_arena_allocator(x); + self.visual_projection = self.visual_projection.with_armnn_arena_allocator(x); + self.textual_projection = self.textual_projection.with_armnn_arena_allocator(x); + self + } + + // migraphx + pub fn with_migraphx_fp16_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_migraphx_fp16(x); + self.textual = self.textual.with_migraphx_fp16(x); + self.model = self.model.with_migraphx_fp16(x); + self.encoder = self.encoder.with_migraphx_fp16(x); + self.decoder = self.decoder.with_migraphx_fp16(x); + self.visual_encoder = self.visual_encoder.with_migraphx_fp16(x); + self.textual_encoder = self.textual_encoder.with_migraphx_fp16(x); + self.visual_decoder = self.visual_decoder.with_migraphx_fp16(x); + self.textual_decoder = self.textual_decoder.with_migraphx_fp16(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_migraphx_fp16(x); + self.size_encoder = self.size_encoder.with_migraphx_fp16(x); + self.size_decoder = self.size_decoder.with_migraphx_fp16(x); + self.coord_encoder = self.coord_encoder.with_migraphx_fp16(x); + self.coord_decoder = self.coord_decoder.with_migraphx_fp16(x); + self.visual_projection = self.visual_projection.with_migraphx_fp16(x); + self.textual_projection = self.textual_projection.with_migraphx_fp16(x); + self + } + + pub fn with_migraphx_exhaustive_tune_all(mut self, x: bool) -> Self { + self.visual = self.visual.with_migraphx_exhaustive_tune(x); + self.textual = self.textual.with_migraphx_exhaustive_tune(x); + self.model = self.model.with_migraphx_exhaustive_tune(x); + self.encoder = self.encoder.with_migraphx_exhaustive_tune(x); + self.decoder = self.decoder.with_migraphx_exhaustive_tune(x); + self.visual_encoder = self.visual_encoder.with_migraphx_exhaustive_tune(x); + self.textual_encoder = self.textual_encoder.with_migraphx_exhaustive_tune(x); + self.visual_decoder = self.visual_decoder.with_migraphx_exhaustive_tune(x); + self.textual_decoder = self.textual_decoder.with_migraphx_exhaustive_tune(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_migraphx_exhaustive_tune(x); + self.size_encoder = self.size_encoder.with_migraphx_exhaustive_tune(x); + self.size_decoder = self.size_decoder.with_migraphx_exhaustive_tune(x); + self.coord_encoder = self.coord_encoder.with_migraphx_exhaustive_tune(x); + self.coord_decoder = self.coord_decoder.with_migraphx_exhaustive_tune(x); + self.visual_projection = self.visual_projection.with_migraphx_exhaustive_tune(x); + self.textual_projection = self.textual_projection.with_migraphx_exhaustive_tune(x); + self + } } impl_ort_config_methods!(Config, model); diff --git a/src/utils/device.rs b/src/utils/device.rs index 97530de..5743193 100644 --- a/src/utils/device.rs +++ b/src/utils/device.rs @@ -3,7 +3,22 @@ pub enum Device { Cpu(usize), Cuda(usize), TensorRt(usize), - CoreMl(usize), + OpenVino(&'static str), + DirectMl(usize), + Cann(usize), + Rocm(usize), + Qnn(usize), + MiGraphX(usize), + CoreMl, + Xnnpack, + RkNpu, + OneDnn, + Acl, + NnApi, + ArmNn, + Tvm, + Vitis, + Azure, } impl Default for Device { @@ -18,41 +33,97 @@ impl std::fmt::Display for Device { Self::Cpu(i) => format!("CPU:{}", i), Self::Cuda(i) => format!("CUDA:{}(NVIDIA)", i), Self::TensorRt(i) => format!("TensorRT:{}(NVIDIA)", i), - Self::CoreMl(i) => format!("CoreML:{}(Apple)", i), + Self::Cann(i) => format!("CANN:{}(Huawei)", i), + Self::OpenVino(s) => format!("OpenVINO:{}(Intel)", s), + Self::DirectMl(i) => format!("DirectML:{}(Microsoft)", i), + Self::Qnn(i) => format!("QNN:{}(Qualcomm)", i), + Self::MiGraphX(i) => format!("MIGraphX:{}(AMD)", i), + Self::Rocm(i) => format!("ROCm:{}(AMD)", i), + Self::CoreMl => "CoreML(Apple)".to_string(), + Self::Azure => "Azure(Microsoft)".to_string(), + Self::Xnnpack => "XNNPACK".to_string(), + Self::OneDnn => "oneDNN(Intel)".to_string(), + Self::RkNpu => "RKNPU".to_string(), + Self::Acl => "ACL(Arm)".to_string(), + Self::NnApi => "NNAPI(Android)".to_string(), + Self::ArmNn => "ArmNN(Arm)".to_string(), + Self::Tvm => "TVM(Apache)".to_string(), + Self::Vitis => "VitisAI(AMD)".to_string(), }; write!(f, "{}", x) } } -impl TryFrom<&str> for Device { - type Error = anyhow::Error; +impl std::str::FromStr for Device { + type Err = anyhow::Error; - fn try_from(s: &str) -> Result { - // device and its id - let d_id: Vec<&str> = s.trim().split(':').collect(); - let (d, id) = match d_id.len() { - 1 => (d_id[0].trim(), 0), - 2 => (d_id[0].trim(), d_id[1].trim().parse::().unwrap_or(0)), - _ => anyhow::bail!( - "Fail to parse device string: {s}. Expect: `device:device_id` or `device`. e.g. `cuda:0` or `cuda`" - ), - }; - // TODO: device-id checking - match d.to_lowercase().as_str() { - "cpu" => Ok(Self::Cpu(id)), - "cuda" => Ok(Self::Cuda(id)), - "trt" | "tensorrt" => Ok(Self::TensorRt(id)), - "coreml" | "mps" => Ok(Self::CoreMl(id)), + fn from_str(s: &str) -> Result { + #[inline] + fn parse_device_id(id_str: Option<&str>) -> usize { + id_str + .map(|s| s.trim().parse::().unwrap_or(0)) + .unwrap_or(0) + } + // Use split_once for better performance - no Vec allocation + let (device_type, id_part) = s + .trim() + .split_once(':') + .map_or_else(|| (s.trim(), None), |(device, id)| (device, Some(id))); + + match device_type.to_lowercase().as_str() { + "cpu" => Ok(Self::Cpu(parse_device_id(id_part))), + "cuda" => Ok(Self::Cuda(parse_device_id(id_part))), + "trt" | "tensorrt" => Ok(Self::TensorRt(parse_device_id(id_part))), + "coreml" | "mps" => Ok(Self::CoreMl), + "openvino" => { + // For OpenVino, use the user input directly after first colon (trimmed) + let device_spec = id_part.map(|s| s.trim()).unwrap_or("CPU"); // Default to CPU if no specification provided + Ok(Self::OpenVino(Box::leak( + device_spec.to_string().into_boxed_str(), + ))) + } + "directml" => Ok(Self::DirectMl(parse_device_id(id_part))), + "xnnpack" => Ok(Self::Xnnpack), + "cann" => Ok(Self::Cann(parse_device_id(id_part))), + "rknpu" => Ok(Self::RkNpu), + "onednn" => Ok(Self::OneDnn), + "acl" => Ok(Self::Acl), + "rocm" => Ok(Self::Rocm(parse_device_id(id_part))), + "nnapi" => Ok(Self::NnApi), + "armnn" => Ok(Self::ArmNn), + "tvm" => Ok(Self::Tvm), + "qnn" => Ok(Self::Qnn(parse_device_id(id_part))), + "migraphx" => Ok(Self::MiGraphX(parse_device_id(id_part))), + "vitisai" => Ok(Self::Vitis), + "azure" => Ok(Self::Azure), _ => anyhow::bail!("Unsupported device str: {s:?}."), } } } impl Device { - pub fn id(&self) -> usize { + pub fn id(&self) -> Option { match self { - Self::Cpu(i) | Self::Cuda(i) | Self::TensorRt(i) | Self::CoreMl(i) => *i, + Self::Cpu(i) + | Self::Cuda(i) + | Self::TensorRt(i) + | Self::Cann(i) + | Self::Qnn(i) + | Self::Rocm(i) + | Self::MiGraphX(i) + | Self::DirectMl(i) => Some(*i), + Self::OpenVino(_) + | Self::Xnnpack + | Self::CoreMl + | Self::RkNpu + | Self::OneDnn + | Self::NnApi + | Self::Azure + | Self::Vitis + | Self::ArmNn + | Self::Tvm + | Self::Acl => None, } } } diff --git a/src/utils/dtype.rs b/src/utils/dtype.rs index 867ddd7..ef73223 100644 --- a/src/utils/dtype.rs +++ b/src/utils/dtype.rs @@ -29,10 +29,10 @@ pub enum DType { Complex128, } -impl TryFrom<&str> for DType { - type Error = anyhow::Error; +impl std::str::FromStr for DType { + type Err = anyhow::Error; - fn try_from(s: &str) -> Result { + fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { "auto" | "dyn" => Ok(Self::Auto), "u4" | "uint4" => Ok(Self::Uint4), diff --git a/src/utils/ort_config.rs b/src/utils/ort_config.rs index 0a7916d..cc4a08b 100644 --- a/src/utils/ort_config.rs +++ b/src/utils/ort_config.rs @@ -9,10 +9,42 @@ pub struct ORTConfig { pub device: Device, pub iiixs: Vec, pub num_dry_run: usize, - pub trt_fp16: bool, - pub graph_opt_level: Option, pub spec: String, // TODO: move out pub dtype: DType, // For dynamically loading the model + // global + pub graph_opt_level: Option, + pub num_intra_threads: Option, + pub num_inter_threads: Option, + // cpu + pub cpu_arena_allocator: bool, + // openvino + pub openvino_dynamic_shapes: bool, + pub openvino_opencl_throttling: bool, + pub openvino_qdq_optimizer: bool, + pub openvino_num_threads: Option, + // onednn + pub onednn_arena_allocator: bool, + // tensorrt + pub tensorrt_fp16: bool, + pub tensorrt_engine_cache: bool, + pub tensorrt_timing_cache: bool, + // coreml + pub coreml_static_input_shapes: bool, + pub coreml_subgraph_running: bool, + // cann + pub cann_graph_inference: bool, + pub cann_dump_graphs: bool, + pub cann_dump_om_model: bool, + // nnapi + pub nnapi_cpu_only: bool, + pub nnapi_disable_cpu: bool, + pub nnapi_fp16: bool, + pub nnapi_nchw: bool, + // armnn + pub armnn_arena_allocator: bool, + // migraphx + pub migraphx_fp16: bool, + pub migraphx_exhaustive_tune: bool, } impl Default for ORTConfig { @@ -21,11 +53,33 @@ impl Default for ORTConfig { file: Default::default(), device: Default::default(), iiixs: Default::default(), - graph_opt_level: Default::default(), spec: Default::default(), dtype: Default::default(), num_dry_run: 3, - trt_fp16: true, + graph_opt_level: Default::default(), + num_intra_threads: None, + num_inter_threads: None, + cpu_arena_allocator: true, + openvino_dynamic_shapes: true, + openvino_opencl_throttling: true, + openvino_qdq_optimizer: true, + openvino_num_threads: None, + coreml_static_input_shapes: false, + coreml_subgraph_running: true, + tensorrt_fp16: true, + tensorrt_engine_cache: true, + tensorrt_timing_cache: false, + cann_graph_inference: true, + cann_dump_graphs: false, + cann_dump_om_model: false, + onednn_arena_allocator: true, + nnapi_cpu_only: false, + nnapi_disable_cpu: false, + nnapi_fp16: true, + nnapi_nchw: false, + armnn_arena_allocator: true, + migraphx_fp16: true, + migraphx_exhaustive_tune: false, } } } @@ -122,10 +176,6 @@ macro_rules! impl_ort_config_methods { self.$field = self.$field.with_device(device); self } - pub fn [](mut self, x: bool) -> Self { - self.$field = self.$field.with_trt_fp16(x); - self - } pub fn [](mut self, x: usize) -> Self { self.$field = self.$field.with_num_dry_run(x); self @@ -134,6 +184,113 @@ macro_rules! impl_ort_config_methods { self.$field = self.$field.with_ixx(i, ii, x); self } + // global + pub fn [](mut self, x: u8) -> Self { + self.$field = self.$field.with_graph_opt_level(x); + self + } + pub fn [](mut self, x: usize) -> Self { + self.$field = self.$field.with_num_intra_threads(x); + self + } + pub fn [](mut self, x: usize) -> Self { + self.$field = self.$field.with_num_inter_threads(x); + self + } + // cpu + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_cpu_arena_allocator(x); + self + } + // openvino + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_openvino_dynamic_shapes(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_openvino_opencl_throttling(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_openvino_qdq_optimizer(x); + self + } + pub fn [](mut self, x: usize) -> Self { + self.$field = self.$field.with_openvino_num_threads(x); + self + } + // onednn + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_onednn_arena_allocator(x); + self + } + + // tensorrt + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_tensorrt_fp16(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_tensorrt_engine_cache(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_tensorrt_timing_cache(x); + self + } + // coreml + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_coreml_static_input_shapes(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_coreml_subgraph_running(x); + self + } + // cann + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_cann_graph_inference(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_cann_dump_graphs(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_cann_dump_om_model(x); + self + } + // nnapi + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_nnapi_cpu_only(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_nnapi_disable_cpu(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_nnapi_fp16(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_nnapi_nchw(x); + self + } + // armnn + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_armnn_arena_allocator(x); + self + } + // migraphx + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_migraphx_fp16(x); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_migraphx_exhaustive_tune(x); + self + } } } }; diff --git a/src/utils/scale.rs b/src/utils/scale.rs index c770317..0949bbc 100644 --- a/src/utils/scale.rs +++ b/src/utils/scale.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + #[derive(Debug, Clone, PartialEq, PartialOrd)] pub enum Scale { N, @@ -64,10 +66,10 @@ impl TryFrom for Scale { } } -impl TryFrom<&str> for Scale { - type Error = anyhow::Error; +impl FromStr for Scale { + type Err = anyhow::Error; - fn try_from(s: &str) -> Result { + fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { "n" | "nano" => Ok(Self::N), "t" | "tiny" => Ok(Self::T), diff --git a/src/utils/task.rs b/src/utils/task.rs index c10dece..af788e7 100644 --- a/src/utils/task.rs +++ b/src/utils/task.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + #[derive(Debug, Clone, Ord, Eq, PartialOrd, PartialEq)] pub enum Task { /// Image classification task. @@ -164,10 +166,10 @@ impl std::fmt::Display for Task { } } -impl TryFrom<&str> for Task { - type Error = anyhow::Error; +impl FromStr for Task { + type Err = anyhow::Error; - fn try_from(s: &str) -> Result { + fn from_str(s: &str) -> Result { // TODO match s.to_lowercase().as_str() { "cls" | "classify" | "classification" => Ok(Self::ImageClassification), diff --git a/src/viz/color.rs b/src/viz/color.rs index e981b43..6319ce6 100644 --- a/src/viz/color.rs +++ b/src/viz/color.rs @@ -61,20 +61,24 @@ impl From for [u8; 3] { } } -impl TryFrom<&str> for Color { - type Error = &'static str; +impl std::str::FromStr for Color { + type Err = anyhow::Error; - fn try_from(x: &str) -> Result { + fn from_str(x: &str) -> Result { let hex = x.trim_start_matches('#'); let hex = match hex.len() { 6 => format!("{}ff", hex), 8 => hex.to_string(), - _ => return Err("Failed to convert `Color` from str: invalid length"), + _ => { + return Err(anyhow::anyhow!( + "Failed to convert `Color` from str: invalid length" + )) + } }; u32::from_str_radix(&hex, 16) .map(Self) - .map_err(|_| "Failed to convert `Color` from str: invalid hex") + .map_err(|_| anyhow::anyhow!("Failed to convert `Color` from str: invalid hex")) } } @@ -151,17 +155,8 @@ impl Color { xs.iter().copied().map(Into::into).collect() } - pub fn try_create_palette + Copy>(xs: &[A]) -> Result> - where - >::Error: std::fmt::Debug, - { - xs.iter() - .copied() - .map(|x| { - x.try_into() - .map_err(|e| anyhow::anyhow!("Failed to convert: {:?}", e)) - }) - .collect() + pub fn try_create_palette(xs: &[&str]) -> Result> { + xs.iter().map(|x| x.parse()).collect() } pub fn palette_rand(n: usize) -> Vec { diff --git a/src/viz/colormap256.rs b/src/viz/colormap256.rs index 56b3f98..2ba4b44 100644 --- a/src/viz/colormap256.rs +++ b/src/viz/colormap256.rs @@ -14,20 +14,22 @@ pub enum ColorMap256 { SmoothCoolWarm, } -impl From<&str> for ColorMap256 { - fn from(s: &str) -> Self { +impl std::str::FromStr for ColorMap256 { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { - "turbo" => Self::Turbo, - "inferno" => Self::Inferno, - "plasma" => Self::Plasma, - "viridis" => Self::Viridis, - "magma" => Self::Magma, - "bentcoolwarm" => Self::BentCoolWarm, - "blackbody" => Self::BlackBody, - "extendedkindlmann" => Self::ExtendedKindLmann, - "kindlmann" => Self::KindLmann, - "smoothcoolwarm" => Self::SmoothCoolWarm, - s => unimplemented!("{} is not supported for now!", s), + "turbo" => Ok(Self::Turbo), + "inferno" => Ok(Self::Inferno), + "plasma" => Ok(Self::Plasma), + "viridis" => Ok(Self::Viridis), + "magma" => Ok(Self::Magma), + "bentcoolwarm" => Ok(Self::BentCoolWarm), + "blackbody" => Ok(Self::BlackBody), + "extendedkindlmann" => Ok(Self::ExtendedKindLmann), + "kindlmann" => Ok(Self::KindLmann), + "smoothcoolwarm" => Ok(Self::SmoothCoolWarm), + _ => Err(anyhow::anyhow!("Unsupported colormap: {}", s)), } } }