Add some eps (#108)

This commit is contained in:
Jamjamjon
2025-06-05 16:29:29 +08:00
committed by GitHub
parent cb587cd57c
commit 0e8d4f832a
41 changed files with 1360 additions and 159 deletions

View File

@ -1,7 +1,7 @@
[package]
name = "usls"
edition = "2021"
version = "0.1.0-beta.3"
version = "0.1.0-beta.4"
rust-version = "1.82"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
@ -45,6 +45,7 @@ ort = { version = "=2.0.0-rc.10", default-features = false, optional = true, fea
] }
tokenizers = { version = "0.21.1" }
paste = "1.0.15"
base64ct = "=1.7.3"
[build-dependencies]
prost-build = "0.13.5"
@ -53,11 +54,27 @@ prost-build = "0.13.5"
argh = "0.1.13"
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "chrono"] }
[features]
default = [ "ort-download-binaries" ]
video = [ "dep:video-rs" ]
ort-download-binaries = [ "ort", "ort/download-binaries" ]
ort-load-dynamic = [ "ort", "ort/load-dynamic" ]
cuda = [ "ort/cuda" ]
trt = [ "ort/tensorrt" ]
tensorrt = [ "ort/tensorrt" ]
coreml = [ "ort/coreml" ]
openvino = [ "ort/openvino" ]
onednn = [ "ort/onednn" ]
directml = [ "ort/directml" ]
xnnpack = [ "ort/xnnpack" ]
cann = [ "ort/cann" ]
rknpu = [ "ort/rknpu" ]
acl = [ "ort/acl" ]
rocm = [ "ort/rocm" ]
nnapi = [ "ort/nnapi" ]
armnn = [ "ort/armnn" ]
tvm = [ "ort/tvm" ]
qnn = [ "ort/qnn" ]
migraphx = [ "ort/migraphx" ]
vitis = [ "ort/vitis" ]
azure = [ "ort/azure" ]

View File

@ -21,8 +21,8 @@ fn main() -> anyhow::Result<()> {
// build model
let config = Config::ben2_base()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = RMBG::new(config)?;

View File

@ -21,7 +21,7 @@ fn main() -> anyhow::Result<()> {
// build model
let config = Config::blip_v1_base_caption()
.with_device_all(args.device.as_str().try_into()?)
.with_device_all(args.device.parse()?)
.commit()?;
let mut model = Blip::new(config)?;

View File

@ -46,8 +46,8 @@ fn main() -> anyhow::Result<()> {
};
let config = config
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = ImageClassifier::try_from(config)?;

View File

@ -29,8 +29,8 @@ fn main() -> Result<()> {
// clip_vit_b32()
// jina_clip_v1()
// jina_clip_v2()
.with_dtype_all(args.dtype.as_str().try_into()?)
.with_device_all(args.device.as_str().try_into()?)
.with_dtype_all(args.dtype.parse()?)
.with_device_all(args.device.parse()?)
.commit()?;
let mut model = Clip::new(config)?;

View File

@ -47,9 +47,9 @@ fn main() -> Result<()> {
// build model
let config = match &args.model {
Some(m) => Config::db().with_model_file(m),
None => Config::ppocr_det_v5_mobile().with_model_dtype(args.dtype.as_str().try_into()?),
None => Config::ppocr_det_v5_mobile().with_model_dtype(args.dtype.parse()?),
}
.with_device_all(args.device.as_str().try_into()?)
.with_device_all(args.device.parse()?)
.commit()?;
let mut model = DB::new(config)?;

View File

@ -18,7 +18,7 @@ fn main() -> Result<()> {
// annotate
let annotator =
Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".into()));
Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".parse()?));
for (x, y) in xs.iter().zip(ys.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",

View File

@ -24,8 +24,8 @@ fn main() -> Result<()> {
// model
let config = Config::depth_pro()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = DepthPro::new(config)?;
@ -38,7 +38,7 @@ fn main() -> Result<()> {
// annotate
let annotator =
Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".into()));
Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".parse()?));
for (x, y) in xs.iter().zip(ys.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",

View File

@ -19,7 +19,7 @@ fn main() -> Result<()> {
// build model
let config = Config::doclayout_yolo_docstructbench()
.with_model_device(args.device.as_str().try_into()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = YOLO::new(config)?;

View File

@ -26,7 +26,7 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
let config = match args.scale.as_str().try_into()? {
let config = match args.scale.parse()? {
Scale::T => Config::fast_tiny(),
Scale::S => Config::fast_small(),
Scale::B => Config::fast_base(),
@ -34,8 +34,8 @@ fn main() -> Result<()> {
};
let mut model = DB::new(
config
.with_dtype_all(args.dtype.as_str().try_into()?)
.with_device_all(args.device.as_str().try_into()?)
.with_dtype_all(args.dtype.parse()?)
.with_device_all(args.device.parse()?)
.commit()?,
)?;

View File

@ -23,8 +23,8 @@ fn main() -> Result<()> {
// build model
let config = Config::fastsam_s()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = YOLO::new(config)?;

View File

@ -26,8 +26,8 @@ fn main() -> Result<()> {
// build model
let config = Config::florence2_base()
.with_dtype_all(args.dtype.as_str().try_into()?)
.with_device_all(args.device.as_str().try_into()?)
.with_dtype_all(args.dtype.parse()?)
.with_device_all(args.device.parse()?)
.with_batch_size_all(xs.len())
.commit()?;
let mut model = Florence2::new(config)?;

View File

@ -46,8 +46,8 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
let config = Config::grounding_dino_tiny()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
.with_class_confs(&[0.25])
.with_text_confs(&[0.25])

View File

@ -27,7 +27,7 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
let config = match args.scale.as_str().try_into()? {
let config = match args.scale.parse()? {
Scale::T => Config::linknet_r18(),
Scale::S => Config::linknet_r34(),
Scale::B => Config::linknet_r50(),
@ -35,8 +35,8 @@ fn main() -> Result<()> {
};
let mut model = DB::new(
config
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.commit()?,
)?;

View File

@ -39,13 +39,13 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
let config = match args.scale.as_str().try_into()? {
let config = match args.scale.parse()? {
Scale::Billion(0.5) => Config::moondream2_0_5b(),
Scale::Billion(2.) => Config::moondream2_2b(),
_ => unimplemented!(),
}
.with_dtype_all(args.dtype.as_str().try_into()?)
.with_device_all(args.device.as_str().try_into()?)
.with_dtype_all(args.dtype.parse()?)
.with_device_all(args.device.parse()?)
.commit()?;
let mut model = Moondream2::new(config)?;
@ -54,7 +54,7 @@ fn main() -> Result<()> {
let xs = DataLoader::try_read_n(&args.source)?;
// run with task
let task: Task = args.task.as_str().try_into()?;
let task: Task = args.task.parse()?;
let ys = model.forward(&xs, &task)?;
// annotate

View File

@ -49,8 +49,8 @@ fn main() -> Result<()> {
// config
let config = Config::owlv2_base_ensemble()
// owlv2_base()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
.commit()?;
let mut model = OWLv2::new(config)?;

View File

@ -31,8 +31,8 @@ fn main() -> anyhow::Result<()> {
// build model
let config = config
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = RMBG::new(config)?;

View File

@ -28,9 +28,9 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// Build model
let config = match args.kind.as_str().try_into()? {
let config = match args.kind.parse()? {
SamKind::Sam => Config::sam_v1_base(),
SamKind::Sam2 => match args.scale.as_str().try_into()? {
SamKind::Sam2 => match args.scale.parse()? {
Scale::T => Config::sam2_tiny(),
Scale::S => Config::sam2_small(),
Scale::B => Config::sam2_base_plus(),
@ -40,7 +40,7 @@ fn main() -> Result<()> {
SamKind::SamHq => Config::sam_hq_tiny(),
SamKind::EdgeSam => Config::edge_sam_3x(),
}
.with_device_all(args.device.as_str().try_into()?)
.with_device_all(args.device.parse()?)
.commit()?;
let mut model = SAM::new(config)?;

View File

@ -25,14 +25,14 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// Build model
let config = match args.scale.as_str().try_into()? {
let config = match args.scale.parse()? {
Scale::T => Config::sam2_1_tiny(),
Scale::S => Config::sam2_1_small(),
Scale::B => Config::sam2_1_base_plus(),
Scale::L => Config::sam2_1_large(),
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale),
}
.with_device_all(args.device.as_str().try_into()?)
.with_device_all(args.device.parse()?)
.commit()?;
let mut model = SAM2::new(config)?;

View File

@ -18,7 +18,7 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build
let config = Config::sapiens_seg_0_3b()
.with_model_device(args.device.as_str().try_into()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = Sapiens::new(config)?;

View File

@ -27,8 +27,8 @@ fn main() -> Result<()> {
// build model
let config = Config::slanet_lcnet_v2_mobile_ch()
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.parse()?)
.with_model_dtype(args.dtype.parse()?)
.commit()?;
let mut model = SLANet::new(config)?;

View File

@ -29,12 +29,12 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
let config = match args.scale.as_str().try_into()? {
let config = match args.scale.parse()? {
Scale::Million(256.) => Config::smolvlm_256m(),
Scale::Million(500.) => Config::smolvlm_500m(),
_ => unimplemented!(),
}
.with_device_all(args.device.as_str().try_into()?)
.with_device_all(args.device.parse()?)
.commit()?;
let mut model = SmolVLM::new(config)?;

View File

@ -32,8 +32,8 @@ fn main() -> Result<()> {
// ppocr_rec_v4_en()
// repsvtr_ch()
.with_model_ixx(0, 3, args.max_text_length.into())
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.parse()?)
.with_model_dtype(args.dtype.parse()?)
.commit()?;
let mut model = SVTR::new(config)?;

View File

@ -38,19 +38,19 @@ fn main() -> anyhow::Result<()> {
])?;
// build model
let config = match args.scale.as_str().try_into()? {
Scale::S => match args.kind.as_str().try_into()? {
let config = match args.scale.parse()? {
Scale::S => match args.kind.parse()? {
TrOCRKind::Printed => Config::trocr_small_printed(),
TrOCRKind::HandWritten => Config::trocr_small_handwritten(),
},
Scale::B => match args.kind.as_str().try_into()? {
Scale::B => match args.kind.parse()? {
TrOCRKind::Printed => Config::trocr_base_printed(),
TrOCRKind::HandWritten => Config::trocr_base_handwritten(),
},
x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x),
}
.with_device_all(args.device.as_str().try_into()?)
.with_dtype_all(args.dtype.as_str().try_into()?)
.with_device_all(args.device.parse()?)
.with_dtype_all(args.dtype.parse()?)
.commit()?;
let mut model = TrOCR::new(config)?;

View File

@ -23,8 +23,8 @@ fn main() -> Result<()> {
// build model
let config = Config::ultralytics_rtdetr_l()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut model = YOLO::new(config)?;

View File

@ -27,7 +27,7 @@ fn main() -> Result<()> {
let options_yolo = Config::yolo_detect()
.with_scale(Scale::N)
.with_version(8.into())
.with_model_device(args.device.as_str().try_into()?)
.with_model_device(args.device.parse()?)
.commit()?;
let mut yolo = YOLO::new(options_yolo)?;

View File

@ -132,12 +132,12 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
let mut config = Config::yolo()
.with_model_file(&args.model.unwrap_or_default())
.with_task(args.task.as_str().try_into()?)
.with_task(args.task.parse()?)
.with_version(args.ver.try_into()?)
.with_scale(args.scale.as_str().try_into()?)
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_trt_fp16(args.trt_fp16)
.with_scale(args.scale.parse()?)
.with_model_dtype(args.dtype.parse()?)
.with_model_device(args.device.parse()?)
.with_model_tensorrt_fp16(args.trt_fp16)
.with_model_ixx(
0,
0,

View File

@ -28,8 +28,8 @@ fn main() -> Result<()> {
// yoloe_11s_seg_pf()
// yoloe_11m_seg_pf()
// yoloe_11l_seg_pf()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.as_str().parse()?)
.with_model_device(args.device.as_str().parse()?)
.commit()?;
let mut model = YOLO::new(config)?;

View File

@ -66,7 +66,6 @@ pub struct Engine {
pub file: String,
pub spec: String,
pub device: Device,
pub trt_fp16: bool,
#[args(inc)]
pub iiixs: Vec<Iiix>,
#[args(aka = "parameters")]
@ -77,7 +76,50 @@ pub struct Engine {
pub onnx: Option<OnnxIo>,
pub ts: Ts,
pub num_dry_run: usize,
// global
pub graph_opt_level: Option<u8>,
pub num_intra_threads: Option<usize>,
pub num_inter_threads: Option<usize>,
// cpu
pub cpu_arena_allocator: bool,
// tensorrt
pub tensorrt_fp16: bool,
pub tensorrt_engine_cache: bool,
pub tensorrt_timing_cache: bool,
// openvino
pub openvino_dynamic_shapes: bool,
pub openvino_opencl_throttling: bool,
pub openvino_qdq_optimizer: bool,
pub openvino_num_threads: Option<usize>,
// onednn
pub onednn_arena_allocator: bool,
// coreml
pub coreml_static_input_shapes: bool,
pub coreml_subgraph_running: bool,
// cann
pub cann_graph_inference: bool,
pub cann_dump_graphs: bool,
pub cann_dump_om_model: bool,
// nnapi
pub nnapi_cpu_only: bool,
pub nnapi_disable_cpu: bool,
pub nnapi_fp16: bool,
pub nnapi_nchw: bool,
// armnn
pub armnn_arena_allocator: bool,
// migraphx
pub migraphx_fp16: bool,
pub migraphx_exhaustive_tune: bool,
}
impl Default for Engine {
@ -85,7 +127,6 @@ impl Default for Engine {
Self {
file: Default::default(),
device: Device::Cpu(0),
trt_fp16: false,
spec: Default::default(),
iiixs: Default::default(),
num_dry_run: 3,
@ -94,7 +135,40 @@ impl Default for Engine {
inputs_minoptmax: vec![],
onnx: None,
ts: Ts::default(),
// global
graph_opt_level: None,
num_intra_threads: None,
num_inter_threads: None,
// cpu
cpu_arena_allocator: true,
// openvino
openvino_dynamic_shapes: true,
openvino_opencl_throttling: true,
openvino_qdq_optimizer: true,
openvino_num_threads: None,
// onednn
onednn_arena_allocator: true,
// coreml
coreml_static_input_shapes: false,
coreml_subgraph_running: true,
// tensorrt
tensorrt_fp16: true,
tensorrt_engine_cache: true,
tensorrt_timing_cache: false,
// cann
cann_graph_inference: true,
cann_dump_graphs: false,
cann_dump_om_model: false,
// nnapi
nnapi_cpu_only: false,
nnapi_disable_cpu: false,
nnapi_fp16: true,
nnapi_nchw: false,
// armnn
armnn_arena_allocator: true,
// migraphx
migraphx_fp16: true,
migraphx_exhaustive_tune: false,
}
}
}
@ -106,9 +180,40 @@ impl Engine {
spec: config.spec.clone(),
iiixs: config.iiixs.clone(),
device: config.device,
trt_fp16: config.trt_fp16,
num_dry_run: config.num_dry_run,
// global
graph_opt_level: config.graph_opt_level,
num_intra_threads: config.num_intra_threads,
num_inter_threads: config.num_inter_threads,
// cpu
cpu_arena_allocator: config.cpu_arena_allocator,
// openvino
openvino_dynamic_shapes: config.openvino_dynamic_shapes,
openvino_opencl_throttling: config.openvino_opencl_throttling,
openvino_qdq_optimizer: config.openvino_qdq_optimizer,
openvino_num_threads: config.openvino_num_threads,
// coreml
coreml_static_input_shapes: config.coreml_static_input_shapes,
coreml_subgraph_running: config.coreml_subgraph_running,
// tensorrt
tensorrt_fp16: config.tensorrt_fp16,
tensorrt_engine_cache: config.tensorrt_engine_cache,
tensorrt_timing_cache: config.tensorrt_timing_cache,
// cann
cann_graph_inference: config.cann_graph_inference,
cann_dump_graphs: config.cann_dump_graphs,
cann_dump_om_model: config.cann_dump_om_model,
// nnapi
nnapi_cpu_only: config.nnapi_cpu_only,
nnapi_disable_cpu: config.nnapi_disable_cpu,
nnapi_fp16: config.nnapi_fp16,
nnapi_nchw: config.nnapi_nchw,
// armnn
armnn_arena_allocator: config.armnn_arena_allocator,
// migraphx
migraphx_fp16: config.migraphx_fp16,
migraphx_exhaustive_tune: config.migraphx_exhaustive_tune,
..Default::default()
}
.build()
@ -338,17 +443,20 @@ impl Engine {
let compile_help = "Please compile ONNXRuntime with #EP";
let feature_help = "#EP EP requires the features: `#FEATURE`. \
\nConsider enabling them by passing, e.g., `--features #FEATURE`";
let n_threads_available = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
match self.device {
Device::TensorRt(id) => {
#[cfg(not(feature = "trt"))]
#[cfg(not(feature = "tensorrt"))]
{
anyhow::bail!(feature_help
.replace("#EP", "TensorRT")
.replace("#FEATURE", "trt"));
.replace("#FEATURE", "tensorrt"));
}
#[cfg(feature = "trt")]
#[cfg(feature = "tensorrt")]
{
// generate shapes
let mut spec_min = String::new();
@ -379,13 +487,16 @@ impl Engine {
spec_max += &s_max;
}
let p = crate::Dir::Cache.crate_dir_default_with_subs(&["trt-cache"])?;
let ep = ort::execution_providers::TensorRTExecutionProvider::default()
.with_device_id(id as i32)
.with_fp16(self.trt_fp16)
.with_engine_cache(true)
.with_engine_cache_path(p.to_str().unwrap())
.with_timing_cache(false)
.with_fp16(self.tensorrt_fp16)
.with_engine_cache(self.tensorrt_engine_cache)
.with_timing_cache(self.tensorrt_timing_cache)
.with_engine_cache_path(
crate::Dir::Cache
.crate_dir_default_with_subs(&["caches", "tensorrt"])?
.display(),
)
.with_profile_min_shapes(spec_min)
.with_profile_opt_shapes(spec_opt)
.with_profile_max_shapes(spec_max);
@ -427,7 +538,7 @@ impl Engine {
}
}
}
Device::CoreMl(id) => {
Device::CoreMl => {
#[cfg(not(feature = "coreml"))]
{
anyhow::bail!(feature_help
@ -439,12 +550,12 @@ impl Engine {
let ep = ort::execution_providers::CoreMLExecutionProvider::default()
.with_model_cache_dir(
crate::Dir::Cache
.crate_dir_default_with_subs(&["coreml-cache"])?
.crate_dir_default_with_subs(&["caches", "coreml"])?
.display(),
)
.with_static_input_shapes(self.coreml_static_input_shapes)
.with_subgraphs(self.coreml_subgraph_running)
.with_compute_units(ort::execution_providers::coreml::CoreMLComputeUnits::All)
.with_static_input_shapes(false)
.with_subgraphs(true)
.with_model_format(ort::execution_providers::coreml::CoreMLModelFormat::MLProgram)
.with_specialization_strategy(
ort::execution_providers::coreml::CoreMLSpecializationStrategy::FastPrediction,
@ -459,9 +570,345 @@ impl Engine {
}
}
}
Device::OpenVino(dt) => {
#[cfg(not(feature = "openvino"))]
{
anyhow::bail!(feature_help
.replace("#EP", "OpenVINO")
.replace("#FEATURE", "openvino"));
}
#[cfg(feature = "openvino")]
{
let ep = ort::execution_providers::OpenVINOExecutionProvider::default()
.with_device_type(dt)
.with_num_threads(self.openvino_num_threads.unwrap_or(n_threads_available))
.with_dynamic_shapes(self.openvino_dynamic_shapes)
.with_opencl_throttling(self.openvino_opencl_throttling)
.with_qdq_optimizer(self.openvino_qdq_optimizer)
.with_cache_dir(
crate::Dir::Cache
.crate_dir_default_with_subs(&["caches", "openvino"])?
.display()
.to_string(),
);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register OpenVINO: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "OpenVINO")),
}
}
}
Device::DirectMl(id) => {
#[cfg(not(feature = "directml"))]
{
anyhow::bail!(feature_help
.replace("#EP", "DirectML")
.replace("#FEATURE", "directml"));
}
#[cfg(feature = "directml")]
{
let ep = ort::execution_providers::DirectMLExecutionProvider::default()
.with_device_id(id as i32);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register DirectML: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "DirectML")),
}
}
}
Device::Xnnpack => {
#[cfg(not(feature = "xnnpack"))]
{
anyhow::bail!(feature_help
.replace("#EP", "XNNPack")
.replace("#FEATURE", "xnnpack"));
}
#[cfg(feature = "xnnpack")]
{
let ep = ort::execution_providers::XNNPACKExecutionProvider::default();
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register XNNPack: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "XNNPack")),
}
}
}
Device::Cann(id) => {
#[cfg(not(feature = "cann"))]
{
anyhow::bail!(feature_help
.replace("#EP", "CANN")
.replace("#FEATURE", "cann"));
}
#[cfg(feature = "cann")]
{
let ep = ort::execution_providers::CANNExecutionProvider::default()
.with_device_id(id as i32)
.with_cann_graph(self.cann_graph_inference)
.with_dump_graphs(self.cann_dump_graphs)
.with_dump_om_model(self.cann_dump_om_model);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register CANN: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "CANN")),
}
}
}
Device::RkNpu => {
#[cfg(not(feature = "rknpu"))]
{
anyhow::bail!(feature_help
.replace("#EP", "RKNPU")
.replace("#FEATURE", "rknpu"));
}
#[cfg(feature = "rknpu")]
{
let ep = ort::execution_providers::RKNPUExecutionProvider::default();
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register RKNPU: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "RKNPU")),
}
}
}
Device::OneDnn => {
#[cfg(not(feature = "onednn"))]
{
anyhow::bail!(feature_help
.replace("#EP", "oneDNN")
.replace("#FEATURE", "onednn"));
}
#[cfg(feature = "onednn")]
{
let ep = ort::execution_providers::OneDNNExecutionProvider::default()
.with_arena_allocator(self.onednn_arena_allocator);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register oneDNN: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "oneDNN")),
}
}
}
Device::Acl => {
#[cfg(not(feature = "acl"))]
{
anyhow::bail!(feature_help
.replace("#EP", "ArmACL")
.replace("#FEATURE", "acl"));
}
#[cfg(feature = "acl")]
{
let ep = ort::execution_providers::ACLExecutionProvider::default()
.with_fast_math(true);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register ArmACL: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "ArmACL")),
}
}
}
Device::Rocm(id) => {
#[cfg(not(feature = "rocm"))]
{
anyhow::bail!(feature_help
.replace("#EP", "ROCm")
.replace("#FEATURE", "rocm"));
}
#[cfg(feature = "rocm")]
{
let ep = ort::execution_providers::ROCmExecutionProvider::default()
.with_device_id(id as _);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register ROCm: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "ROCm")),
}
}
}
Device::NnApi => {
#[cfg(not(feature = "nnapi"))]
{
anyhow::bail!(feature_help
.replace("#EP", "NNAPI")
.replace("#FEATURE", "nnapi"));
}
#[cfg(feature = "nnapi")]
{
let ep = ort::execution_providers::NNAPIExecutionProvider::default()
.with_fp16(self.nnapi_fp16)
.with_nchw(self.nnapi_nchw)
.with_cpu_only(self.nnapi_cpu_only)
.with_disable_cpu(self.nnapi_disable_cpu);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register NNAPI: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "NNAPI")),
}
}
}
Device::ArmNn => {
#[cfg(not(feature = "armnn"))]
{
anyhow::bail!(feature_help
.replace("#EP", "ArmNN")
.replace("#FEATURE", "armnn"));
}
#[cfg(feature = "armnn")]
{
let ep = ort::execution_providers::ArmNNExecutionProvider::default()
.with_arena_allocator(self.armnn_arena_allocator);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register ArmNN: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "ArmNN")),
}
}
}
Device::Tvm => {
#[cfg(not(feature = "tvm"))]
{
anyhow::bail!(feature_help
.replace("#EP", "TVM")
.replace("#FEATURE", "tvm"));
}
#[cfg(feature = "tvm")]
{
let ep = ort::execution_providers::TVMExecutionProvider::default();
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register TVM: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "TVM")),
}
}
}
Device::Qnn(id) => {
#[cfg(not(feature = "qnn"))]
{
anyhow::bail!(feature_help
.replace("#EP", "QNN")
.replace("#FEATURE", "qnn"));
}
#[cfg(feature = "qnn")]
{
let ep = ort::execution_providers::QNNExecutionProvider::default()
.with_device_id(id as _);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register QNN: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "QNN")),
}
}
}
Device::MiGraphX(id) => {
#[cfg(not(feature = "migraphx"))]
{
anyhow::bail!(feature_help
.replace("#EP", "MIGraphX")
.replace("#FEATURE", "migraphx"));
}
#[cfg(feature = "migraphx")]
{
let ep = ort::execution_providers::MIGraphXExecutionProvider::default()
.with_device_id(id as _)
.with_fp16(self.migraphx_fp16)
.with_exhaustive_tune(self.migraphx_exhaustive_tune);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register MIGraphX: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "MIGraphX")),
}
}
}
Device::Vitis => {
#[cfg(not(feature = "vitis"))]
{
anyhow::bail!(feature_help
.replace("#EP", "VitisAI")
.replace("#FEATURE", "vitis"));
}
#[cfg(feature = "vitis")]
{
let ep = ort::execution_providers::VitisAIExecutionProvider::default()
.with_cache_dir(
crate::Dir::Cache
.crate_dir_default_with_subs(&["caches", "vitis"])?
.display()
.to_string(),
);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register VitisAI: {}", err)
})?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "VitisAI")),
}
}
}
Device::Azure => {
#[cfg(not(feature = "azure"))]
{
anyhow::bail!(feature_help
.replace("#EP", "Azure")
.replace("#FEATURE", "azure"));
}
#[cfg(feature = "azure")]
{
let ep = ort::execution_providers::AzureExecutionProvider::default();
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder).map_err(|err| {
anyhow::anyhow!("Failed to register Azure: {}", err)
})?;
builder = builder.with_extensions()?;
}
_ => anyhow::bail!(compile_help.replace("#EP", "Azure")),
}
}
}
_ => {
let ep = ort::execution_providers::CPUExecutionProvider::default()
.with_arena_allocator(true);
.with_arena_allocator(self.cpu_arena_allocator);
match ep.is_available() {
Ok(true) => {
ep.register(&mut builder)
@ -481,7 +928,8 @@ impl Engine {
};
let session = builder
.with_optimization_level(graph_opt_level)?
.with_intra_threads(std::thread::available_parallelism()?.get())?
.with_intra_threads(self.num_intra_threads.unwrap_or(n_threads_available))?
.with_inter_threads(self.num_inter_threads.unwrap_or(2))?
.commit_from_file(self.file())?;
Ok(session)

View File

@ -5,6 +5,7 @@ use log::{info, warn};
use rayon::prelude::*;
use std::collections::VecDeque;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::mpsc;
#[cfg(feature = "video")]
use video_rs::{Decoder, Url};
@ -80,9 +81,10 @@ impl std::fmt::Debug for DataLoader {
}
}
impl TryFrom<&str> for DataLoader {
type Error = anyhow::Error;
fn try_from(source: &str) -> Result<Self, Self::Error> {
impl FromStr for DataLoader {
type Err = anyhow::Error;
fn from_str(source: &str) -> Result<Self, Self::Err> {
Self::new(source)
}
}

View File

@ -460,7 +460,7 @@ impl Hub {
fn cache_file(owner: &str, repo: &str) -> String {
let safe_owner = owner.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
let safe_repo = repo.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
format!(".cache-releases-{}-{}.json", safe_owner, safe_repo)
format!("releases-{}-{}.json", safe_owner, safe_repo)
}
fn get_releases(
@ -470,7 +470,9 @@ impl Hub {
to: &Dir,
ttl: &Duration,
) -> Result<Vec<Release>> {
let cache = to.crate_dir_default()?.join(Self::cache_file(owner, repo));
let cache = to
.crate_dir_default_with_subs(&["caches"])?
.join(Self::cache_file(owner, repo));
let is_file_expired = Self::is_file_expired(&cache, ttl)?;
let body = if is_file_expired {
let gh_api_release = format!(

View File

@ -2,6 +2,7 @@ use aksr::Builder;
use anyhow::Result;
use ndarray::{s, Axis};
use rand::{prelude::*, rng};
use std::str::FromStr;
use crate::{
elapsed, Config, DynConf, Engine, Image, Mask, Ops, Polygon, Processor, SamPrompt, Ts, Xs, X, Y,
@ -16,10 +17,10 @@ pub enum SamKind {
EdgeSam,
}
impl TryFrom<&str> for SamKind {
type Error = anyhow::Error;
impl FromStr for SamKind {
type Err = anyhow::Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"sam" => Ok(Self::Sam),
"sam2" => Ok(Self::Sam2),

View File

@ -2,6 +2,7 @@ use aksr::Builder;
use anyhow::Result;
use ndarray::{s, Axis};
use rayon::prelude::*;
use std::str::FromStr;
use crate::{elapsed, Config, Engine, Image, LogitsSampler, Processor, Scale, Ts, Xs, X, Y};
@ -11,10 +12,10 @@ pub enum TrOCRKind {
HandWritten,
}
impl TryFrom<&str> for TrOCRKind {
type Error = anyhow::Error;
impl FromStr for TrOCRKind {
type Err = anyhow::Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"printed" => Ok(Self::Printed),
"handwritten" | "hand-written" => Ok(Self::HandWritten),

View File

@ -261,6 +261,507 @@ impl Config {
self
}
pub fn with_graph_opt_level_all(mut self, level: u8) -> Self {
self.visual = self.visual.with_graph_opt_level(level);
self.textual = self.textual.with_graph_opt_level(level);
self.model = self.model.with_graph_opt_level(level);
self.encoder = self.encoder.with_graph_opt_level(level);
self.decoder = self.decoder.with_graph_opt_level(level);
self.visual_encoder = self.visual_encoder.with_graph_opt_level(level);
self.textual_encoder = self.textual_encoder.with_graph_opt_level(level);
self.visual_decoder = self.visual_decoder.with_graph_opt_level(level);
self.textual_decoder = self.textual_decoder.with_graph_opt_level(level);
self.textual_decoder_merged = self.textual_decoder_merged.with_graph_opt_level(level);
self.size_encoder = self.size_encoder.with_graph_opt_level(level);
self.size_decoder = self.size_decoder.with_graph_opt_level(level);
self.coord_encoder = self.coord_encoder.with_graph_opt_level(level);
self.coord_decoder = self.coord_decoder.with_graph_opt_level(level);
self.visual_projection = self.visual_projection.with_graph_opt_level(level);
self.textual_projection = self.textual_projection.with_graph_opt_level(level);
self
}
pub fn with_num_intra_threads_all(mut self, num_threads: usize) -> Self {
self.visual = self.visual.with_num_intra_threads(num_threads);
self.textual = self.textual.with_num_intra_threads(num_threads);
self.model = self.model.with_num_intra_threads(num_threads);
self.encoder = self.encoder.with_num_intra_threads(num_threads);
self.decoder = self.decoder.with_num_intra_threads(num_threads);
self.visual_encoder = self.visual_encoder.with_num_intra_threads(num_threads);
self.textual_encoder = self.textual_encoder.with_num_intra_threads(num_threads);
self.visual_decoder = self.visual_decoder.with_num_intra_threads(num_threads);
self.textual_decoder = self.textual_decoder.with_num_intra_threads(num_threads);
self.textual_decoder_merged = self
.textual_decoder_merged
.with_num_intra_threads(num_threads);
self.size_encoder = self.size_encoder.with_num_intra_threads(num_threads);
self.size_decoder = self.size_decoder.with_num_intra_threads(num_threads);
self.coord_encoder = self.coord_encoder.with_num_intra_threads(num_threads);
self.coord_decoder = self.coord_decoder.with_num_intra_threads(num_threads);
self.visual_projection = self.visual_projection.with_num_intra_threads(num_threads);
self.textual_projection = self.textual_projection.with_num_intra_threads(num_threads);
self
}
pub fn with_num_inter_threads_all(mut self, num_threads: usize) -> Self {
self.visual = self.visual.with_num_inter_threads(num_threads);
self.textual = self.textual.with_num_inter_threads(num_threads);
self.model = self.model.with_num_inter_threads(num_threads);
self.encoder = self.encoder.with_num_inter_threads(num_threads);
self.decoder = self.decoder.with_num_inter_threads(num_threads);
self.visual_encoder = self.visual_encoder.with_num_inter_threads(num_threads);
self.textual_encoder = self.textual_encoder.with_num_inter_threads(num_threads);
self.visual_decoder = self.visual_decoder.with_num_inter_threads(num_threads);
self.textual_decoder = self.textual_decoder.with_num_inter_threads(num_threads);
self.textual_decoder_merged = self
.textual_decoder_merged
.with_num_inter_threads(num_threads);
self.size_encoder = self.size_encoder.with_num_inter_threads(num_threads);
self.size_decoder = self.size_decoder.with_num_inter_threads(num_threads);
self.coord_encoder = self.coord_encoder.with_num_inter_threads(num_threads);
self.coord_decoder = self.coord_decoder.with_num_inter_threads(num_threads);
self.visual_projection = self.visual_projection.with_num_inter_threads(num_threads);
self.textual_projection = self.textual_projection.with_num_inter_threads(num_threads);
self
}
pub fn with_cpu_arena_allocator_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_cpu_arena_allocator(x);
self.textual = self.textual.with_cpu_arena_allocator(x);
self.model = self.model.with_cpu_arena_allocator(x);
self.encoder = self.encoder.with_cpu_arena_allocator(x);
self.decoder = self.decoder.with_cpu_arena_allocator(x);
self.visual_encoder = self.visual_encoder.with_cpu_arena_allocator(x);
self.textual_encoder = self.textual_encoder.with_cpu_arena_allocator(x);
self.visual_decoder = self.visual_decoder.with_cpu_arena_allocator(x);
self.textual_decoder = self.textual_decoder.with_cpu_arena_allocator(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_cpu_arena_allocator(x);
self.size_encoder = self.size_encoder.with_cpu_arena_allocator(x);
self.size_decoder = self.size_decoder.with_cpu_arena_allocator(x);
self.coord_encoder = self.coord_encoder.with_cpu_arena_allocator(x);
self.coord_decoder = self.coord_decoder.with_cpu_arena_allocator(x);
self.visual_projection = self.visual_projection.with_cpu_arena_allocator(x);
self.textual_projection = self.textual_projection.with_cpu_arena_allocator(x);
self
}
pub fn with_openvino_dynamic_shapes_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_openvino_dynamic_shapes(x);
self.textual = self.textual.with_openvino_dynamic_shapes(x);
self.model = self.model.with_openvino_dynamic_shapes(x);
self.encoder = self.encoder.with_openvino_dynamic_shapes(x);
self.decoder = self.decoder.with_openvino_dynamic_shapes(x);
self.visual_encoder = self.visual_encoder.with_openvino_dynamic_shapes(x);
self.textual_encoder = self.textual_encoder.with_openvino_dynamic_shapes(x);
self.visual_decoder = self.visual_decoder.with_openvino_dynamic_shapes(x);
self.textual_decoder = self.textual_decoder.with_openvino_dynamic_shapes(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_openvino_dynamic_shapes(x);
self.size_encoder = self.size_encoder.with_openvino_dynamic_shapes(x);
self.size_decoder = self.size_decoder.with_openvino_dynamic_shapes(x);
self.coord_encoder = self.coord_encoder.with_openvino_dynamic_shapes(x);
self.coord_decoder = self.coord_decoder.with_openvino_dynamic_shapes(x);
self.visual_projection = self.visual_projection.with_openvino_dynamic_shapes(x);
self.textual_projection = self.textual_projection.with_openvino_dynamic_shapes(x);
self
}
pub fn with_openvino_opencl_throttling_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_openvino_opencl_throttling(x);
self.textual = self.textual.with_openvino_opencl_throttling(x);
self.model = self.model.with_openvino_opencl_throttling(x);
self.encoder = self.encoder.with_openvino_opencl_throttling(x);
self.decoder = self.decoder.with_openvino_opencl_throttling(x);
self.visual_encoder = self.visual_encoder.with_openvino_opencl_throttling(x);
self.textual_encoder = self.textual_encoder.with_openvino_opencl_throttling(x);
self.visual_decoder = self.visual_decoder.with_openvino_opencl_throttling(x);
self.textual_decoder = self.textual_decoder.with_openvino_opencl_throttling(x);
self.textual_decoder_merged = self
.textual_decoder_merged
.with_openvino_opencl_throttling(x);
self.size_encoder = self.size_encoder.with_openvino_opencl_throttling(x);
self.size_decoder = self.size_decoder.with_openvino_opencl_throttling(x);
self.coord_encoder = self.coord_encoder.with_openvino_opencl_throttling(x);
self.coord_decoder = self.coord_decoder.with_openvino_opencl_throttling(x);
self.visual_projection = self.visual_projection.with_openvino_opencl_throttling(x);
self.textual_projection = self.textual_projection.with_openvino_opencl_throttling(x);
self
}
pub fn with_openvino_qdq_optimizer_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_openvino_qdq_optimizer(x);
self.textual = self.textual.with_openvino_qdq_optimizer(x);
self.model = self.model.with_openvino_qdq_optimizer(x);
self.encoder = self.encoder.with_openvino_qdq_optimizer(x);
self.decoder = self.decoder.with_openvino_qdq_optimizer(x);
self.visual_encoder = self.visual_encoder.with_openvino_qdq_optimizer(x);
self.textual_encoder = self.textual_encoder.with_openvino_qdq_optimizer(x);
self.visual_decoder = self.visual_decoder.with_openvino_qdq_optimizer(x);
self.textual_decoder = self.textual_decoder.with_openvino_qdq_optimizer(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_openvino_qdq_optimizer(x);
self.size_encoder = self.size_encoder.with_openvino_qdq_optimizer(x);
self.size_decoder = self.size_decoder.with_openvino_qdq_optimizer(x);
self.coord_encoder = self.coord_encoder.with_openvino_qdq_optimizer(x);
self.coord_decoder = self.coord_decoder.with_openvino_qdq_optimizer(x);
self.visual_projection = self.visual_projection.with_openvino_qdq_optimizer(x);
self.textual_projection = self.textual_projection.with_openvino_qdq_optimizer(x);
self
}
pub fn with_openvino_num_threads_all(mut self, num_threads: usize) -> Self {
self.visual = self.visual.with_openvino_num_threads(num_threads);
self.textual = self.textual.with_openvino_num_threads(num_threads);
self.model = self.model.with_openvino_num_threads(num_threads);
self.encoder = self.encoder.with_openvino_num_threads(num_threads);
self.decoder = self.decoder.with_openvino_num_threads(num_threads);
self.visual_encoder = self.visual_encoder.with_openvino_num_threads(num_threads);
self.textual_encoder = self.textual_encoder.with_openvino_num_threads(num_threads);
self.visual_decoder = self.visual_decoder.with_openvino_num_threads(num_threads);
self.textual_decoder = self.textual_decoder.with_openvino_num_threads(num_threads);
self.textual_decoder_merged = self
.textual_decoder_merged
.with_openvino_num_threads(num_threads);
self.size_encoder = self.size_encoder.with_openvino_num_threads(num_threads);
self.size_decoder = self.size_decoder.with_openvino_num_threads(num_threads);
self.coord_encoder = self.coord_encoder.with_openvino_num_threads(num_threads);
self.coord_decoder = self.coord_decoder.with_openvino_num_threads(num_threads);
self.visual_projection = self
.visual_projection
.with_openvino_num_threads(num_threads);
self.textual_projection = self
.textual_projection
.with_openvino_num_threads(num_threads);
self
}
// onednn
pub fn with_onednn_arena_allocator_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_onednn_arena_allocator(x);
self.textual = self.textual.with_onednn_arena_allocator(x);
self.model = self.model.with_onednn_arena_allocator(x);
self.encoder = self.encoder.with_onednn_arena_allocator(x);
self.decoder = self.decoder.with_onednn_arena_allocator(x);
self.visual_encoder = self.visual_encoder.with_onednn_arena_allocator(x);
self.textual_encoder = self.textual_encoder.with_onednn_arena_allocator(x);
self.visual_decoder = self.visual_decoder.with_onednn_arena_allocator(x);
self.textual_decoder = self.textual_decoder.with_onednn_arena_allocator(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_onednn_arena_allocator(x);
self.size_encoder = self.size_encoder.with_onednn_arena_allocator(x);
self.size_decoder = self.size_decoder.with_onednn_arena_allocator(x);
self.coord_encoder = self.coord_encoder.with_onednn_arena_allocator(x);
self.coord_decoder = self.coord_decoder.with_onednn_arena_allocator(x);
self.visual_projection = self.visual_projection.with_onednn_arena_allocator(x);
self.textual_projection = self.textual_projection.with_onednn_arena_allocator(x);
self
}
// tensorrt
pub fn with_tensorrt_fp16_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_tensorrt_fp16(x);
self.textual = self.textual.with_tensorrt_fp16(x);
self.model = self.model.with_tensorrt_fp16(x);
self.encoder = self.encoder.with_tensorrt_fp16(x);
self.decoder = self.decoder.with_tensorrt_fp16(x);
self.visual_encoder = self.visual_encoder.with_tensorrt_fp16(x);
self.textual_encoder = self.textual_encoder.with_tensorrt_fp16(x);
self.visual_decoder = self.visual_decoder.with_tensorrt_fp16(x);
self.textual_decoder = self.textual_decoder.with_tensorrt_fp16(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_tensorrt_fp16(x);
self.size_encoder = self.size_encoder.with_tensorrt_fp16(x);
self.size_decoder = self.size_decoder.with_tensorrt_fp16(x);
self.coord_encoder = self.coord_encoder.with_tensorrt_fp16(x);
self.coord_decoder = self.coord_decoder.with_tensorrt_fp16(x);
self.visual_projection = self.visual_projection.with_tensorrt_fp16(x);
self.textual_projection = self.textual_projection.with_tensorrt_fp16(x);
self
}
pub fn with_tensorrt_engine_cache_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_tensorrt_engine_cache(x);
self.textual = self.textual.with_tensorrt_engine_cache(x);
self.model = self.model.with_tensorrt_engine_cache(x);
self.encoder = self.encoder.with_tensorrt_engine_cache(x);
self.decoder = self.decoder.with_tensorrt_engine_cache(x);
self.visual_encoder = self.visual_encoder.with_tensorrt_engine_cache(x);
self.textual_encoder = self.textual_encoder.with_tensorrt_engine_cache(x);
self.visual_decoder = self.visual_decoder.with_tensorrt_engine_cache(x);
self.textual_decoder = self.textual_decoder.with_tensorrt_engine_cache(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_tensorrt_engine_cache(x);
self.size_encoder = self.size_encoder.with_tensorrt_engine_cache(x);
self.size_decoder = self.size_decoder.with_tensorrt_engine_cache(x);
self.coord_encoder = self.coord_encoder.with_tensorrt_engine_cache(x);
self.coord_decoder = self.coord_decoder.with_tensorrt_engine_cache(x);
self.visual_projection = self.visual_projection.with_tensorrt_engine_cache(x);
self.textual_projection = self.textual_projection.with_tensorrt_engine_cache(x);
self
}
pub fn with_tensorrt_timing_cache_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_tensorrt_timing_cache(x);
self.textual = self.textual.with_tensorrt_timing_cache(x);
self.model = self.model.with_tensorrt_timing_cache(x);
self.encoder = self.encoder.with_tensorrt_timing_cache(x);
self.decoder = self.decoder.with_tensorrt_timing_cache(x);
self.visual_encoder = self.visual_encoder.with_tensorrt_timing_cache(x);
self.textual_encoder = self.textual_encoder.with_tensorrt_timing_cache(x);
self.visual_decoder = self.visual_decoder.with_tensorrt_timing_cache(x);
self.textual_decoder = self.textual_decoder.with_tensorrt_timing_cache(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_tensorrt_timing_cache(x);
self.size_encoder = self.size_encoder.with_tensorrt_timing_cache(x);
self.size_decoder = self.size_decoder.with_tensorrt_timing_cache(x);
self.coord_encoder = self.coord_encoder.with_tensorrt_timing_cache(x);
self.coord_decoder = self.coord_decoder.with_tensorrt_timing_cache(x);
self.visual_projection = self.visual_projection.with_tensorrt_timing_cache(x);
self.textual_projection = self.textual_projection.with_tensorrt_timing_cache(x);
self
}
// coreml
pub fn with_coreml_static_input_shapes_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_coreml_static_input_shapes(x);
self.textual = self.textual.with_coreml_static_input_shapes(x);
self.model = self.model.with_coreml_static_input_shapes(x);
self.encoder = self.encoder.with_coreml_static_input_shapes(x);
self.decoder = self.decoder.with_coreml_static_input_shapes(x);
self.visual_encoder = self.visual_encoder.with_coreml_static_input_shapes(x);
self.textual_encoder = self.textual_encoder.with_coreml_static_input_shapes(x);
self.visual_decoder = self.visual_decoder.with_coreml_static_input_shapes(x);
self.textual_decoder = self.textual_decoder.with_coreml_static_input_shapes(x);
self.textual_decoder_merged = self
.textual_decoder_merged
.with_coreml_static_input_shapes(x);
self.size_encoder = self.size_encoder.with_coreml_static_input_shapes(x);
self.size_decoder = self.size_decoder.with_coreml_static_input_shapes(x);
self.coord_encoder = self.coord_encoder.with_coreml_static_input_shapes(x);
self.coord_decoder = self.coord_decoder.with_coreml_static_input_shapes(x);
self.visual_projection = self.visual_projection.with_coreml_static_input_shapes(x);
self.textual_projection = self.textual_projection.with_coreml_static_input_shapes(x);
self
}
pub fn with_coreml_subgraph_running_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_coreml_subgraph_running(x);
self.textual = self.textual.with_coreml_subgraph_running(x);
self.model = self.model.with_coreml_subgraph_running(x);
self.encoder = self.encoder.with_coreml_subgraph_running(x);
self.decoder = self.decoder.with_coreml_subgraph_running(x);
self.visual_encoder = self.visual_encoder.with_coreml_subgraph_running(x);
self.textual_encoder = self.textual_encoder.with_coreml_subgraph_running(x);
self.visual_decoder = self.visual_decoder.with_coreml_subgraph_running(x);
self.textual_decoder = self.textual_decoder.with_coreml_subgraph_running(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_coreml_subgraph_running(x);
self.size_encoder = self.size_encoder.with_coreml_subgraph_running(x);
self.size_decoder = self.size_decoder.with_coreml_subgraph_running(x);
self.coord_encoder = self.coord_encoder.with_coreml_subgraph_running(x);
self.coord_decoder = self.coord_decoder.with_coreml_subgraph_running(x);
self.visual_projection = self.visual_projection.with_coreml_subgraph_running(x);
self.textual_projection = self.textual_projection.with_coreml_subgraph_running(x);
self
}
// cann
pub fn with_cann_graph_inference_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_cann_graph_inference(x);
self.textual = self.textual.with_cann_graph_inference(x);
self.model = self.model.with_cann_graph_inference(x);
self.encoder = self.encoder.with_cann_graph_inference(x);
self.decoder = self.decoder.with_cann_graph_inference(x);
self.visual_encoder = self.visual_encoder.with_cann_graph_inference(x);
self.textual_encoder = self.textual_encoder.with_cann_graph_inference(x);
self.visual_decoder = self.visual_decoder.with_cann_graph_inference(x);
self.textual_decoder = self.textual_decoder.with_cann_graph_inference(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_cann_graph_inference(x);
self.size_encoder = self.size_encoder.with_cann_graph_inference(x);
self.size_decoder = self.size_decoder.with_cann_graph_inference(x);
self.coord_encoder = self.coord_encoder.with_cann_graph_inference(x);
self.coord_decoder = self.coord_decoder.with_cann_graph_inference(x);
self.visual_projection = self.visual_projection.with_cann_graph_inference(x);
self.textual_projection = self.textual_projection.with_cann_graph_inference(x);
self
}
pub fn with_cann_dump_graphs_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_cann_dump_graphs(x);
self.textual = self.textual.with_cann_dump_graphs(x);
self.model = self.model.with_cann_dump_graphs(x);
self.encoder = self.encoder.with_cann_dump_graphs(x);
self.decoder = self.decoder.with_cann_dump_graphs(x);
self.visual_encoder = self.visual_encoder.with_cann_dump_graphs(x);
self.textual_encoder = self.textual_encoder.with_cann_dump_graphs(x);
self.visual_decoder = self.visual_decoder.with_cann_dump_graphs(x);
self.textual_decoder = self.textual_decoder.with_cann_dump_graphs(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_cann_dump_graphs(x);
self.size_encoder = self.size_encoder.with_cann_dump_graphs(x);
self.size_decoder = self.size_decoder.with_cann_dump_graphs(x);
self.coord_encoder = self.coord_encoder.with_cann_dump_graphs(x);
self.coord_decoder = self.coord_decoder.with_cann_dump_graphs(x);
self.visual_projection = self.visual_projection.with_cann_dump_graphs(x);
self.textual_projection = self.textual_projection.with_cann_dump_graphs(x);
self
}
pub fn with_cann_dump_om_model_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_cann_dump_om_model(x);
self.textual = self.textual.with_cann_dump_om_model(x);
self.model = self.model.with_cann_dump_om_model(x);
self.encoder = self.encoder.with_cann_dump_om_model(x);
self.decoder = self.decoder.with_cann_dump_om_model(x);
self.visual_encoder = self.visual_encoder.with_cann_dump_om_model(x);
self.textual_encoder = self.textual_encoder.with_cann_dump_om_model(x);
self.visual_decoder = self.visual_decoder.with_cann_dump_om_model(x);
self.textual_decoder = self.textual_decoder.with_cann_dump_om_model(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_cann_dump_om_model(x);
self.size_encoder = self.size_encoder.with_cann_dump_om_model(x);
self.size_decoder = self.size_decoder.with_cann_dump_om_model(x);
self.coord_encoder = self.coord_encoder.with_cann_dump_om_model(x);
self.coord_decoder = self.coord_decoder.with_cann_dump_om_model(x);
self.visual_projection = self.visual_projection.with_cann_dump_om_model(x);
self.textual_projection = self.textual_projection.with_cann_dump_om_model(x);
self
}
// nnapi
pub fn with_nnapi_cpu_only_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_nnapi_cpu_only(x);
self.textual = self.textual.with_nnapi_cpu_only(x);
self.model = self.model.with_nnapi_cpu_only(x);
self.encoder = self.encoder.with_nnapi_cpu_only(x);
self.decoder = self.decoder.with_nnapi_cpu_only(x);
self.visual_encoder = self.visual_encoder.with_nnapi_cpu_only(x);
self.textual_encoder = self.textual_encoder.with_nnapi_cpu_only(x);
self.visual_decoder = self.visual_decoder.with_nnapi_cpu_only(x);
self.textual_decoder = self.textual_decoder.with_nnapi_cpu_only(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_cpu_only(x);
self.size_encoder = self.size_encoder.with_nnapi_cpu_only(x);
self.size_decoder = self.size_decoder.with_nnapi_cpu_only(x);
self.coord_encoder = self.coord_encoder.with_nnapi_cpu_only(x);
self.coord_decoder = self.coord_decoder.with_nnapi_cpu_only(x);
self.visual_projection = self.visual_projection.with_nnapi_cpu_only(x);
self.textual_projection = self.textual_projection.with_nnapi_cpu_only(x);
self
}
pub fn with_nnapi_disable_cpu_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_nnapi_disable_cpu(x);
self.textual = self.textual.with_nnapi_disable_cpu(x);
self.model = self.model.with_nnapi_disable_cpu(x);
self.encoder = self.encoder.with_nnapi_disable_cpu(x);
self.decoder = self.decoder.with_nnapi_disable_cpu(x);
self.visual_encoder = self.visual_encoder.with_nnapi_disable_cpu(x);
self.textual_encoder = self.textual_encoder.with_nnapi_disable_cpu(x);
self.visual_decoder = self.visual_decoder.with_nnapi_disable_cpu(x);
self.textual_decoder = self.textual_decoder.with_nnapi_disable_cpu(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_disable_cpu(x);
self.size_encoder = self.size_encoder.with_nnapi_disable_cpu(x);
self.size_decoder = self.size_decoder.with_nnapi_disable_cpu(x);
self.coord_encoder = self.coord_encoder.with_nnapi_disable_cpu(x);
self.coord_decoder = self.coord_decoder.with_nnapi_disable_cpu(x);
self.visual_projection = self.visual_projection.with_nnapi_disable_cpu(x);
self.textual_projection = self.textual_projection.with_nnapi_disable_cpu(x);
self
}
pub fn with_nnapi_fp16_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_nnapi_fp16(x);
self.textual = self.textual.with_nnapi_fp16(x);
self.model = self.model.with_nnapi_fp16(x);
self.encoder = self.encoder.with_nnapi_fp16(x);
self.decoder = self.decoder.with_nnapi_fp16(x);
self.visual_encoder = self.visual_encoder.with_nnapi_fp16(x);
self.textual_encoder = self.textual_encoder.with_nnapi_fp16(x);
self.visual_decoder = self.visual_decoder.with_nnapi_fp16(x);
self.textual_decoder = self.textual_decoder.with_nnapi_fp16(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_fp16(x);
self.size_encoder = self.size_encoder.with_nnapi_fp16(x);
self.size_decoder = self.size_decoder.with_nnapi_fp16(x);
self.coord_encoder = self.coord_encoder.with_nnapi_fp16(x);
self.coord_decoder = self.coord_decoder.with_nnapi_fp16(x);
self.visual_projection = self.visual_projection.with_nnapi_fp16(x);
self.textual_projection = self.textual_projection.with_nnapi_fp16(x);
self
}
pub fn with_nnapi_nchw_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_nnapi_nchw(x);
self.textual = self.textual.with_nnapi_nchw(x);
self.model = self.model.with_nnapi_nchw(x);
self.encoder = self.encoder.with_nnapi_nchw(x);
self.decoder = self.decoder.with_nnapi_nchw(x);
self.visual_encoder = self.visual_encoder.with_nnapi_nchw(x);
self.textual_encoder = self.textual_encoder.with_nnapi_nchw(x);
self.visual_decoder = self.visual_decoder.with_nnapi_nchw(x);
self.textual_decoder = self.textual_decoder.with_nnapi_nchw(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_nchw(x);
self.size_encoder = self.size_encoder.with_nnapi_nchw(x);
self.size_decoder = self.size_decoder.with_nnapi_nchw(x);
self.coord_encoder = self.coord_encoder.with_nnapi_nchw(x);
self.coord_decoder = self.coord_decoder.with_nnapi_nchw(x);
self.visual_projection = self.visual_projection.with_nnapi_nchw(x);
self.textual_projection = self.textual_projection.with_nnapi_nchw(x);
self
}
// armnn
pub fn with_armnn_arena_allocator_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_armnn_arena_allocator(x);
self.textual = self.textual.with_armnn_arena_allocator(x);
self.model = self.model.with_armnn_arena_allocator(x);
self.encoder = self.encoder.with_armnn_arena_allocator(x);
self.decoder = self.decoder.with_armnn_arena_allocator(x);
self.visual_encoder = self.visual_encoder.with_armnn_arena_allocator(x);
self.textual_encoder = self.textual_encoder.with_armnn_arena_allocator(x);
self.visual_decoder = self.visual_decoder.with_armnn_arena_allocator(x);
self.textual_decoder = self.textual_decoder.with_armnn_arena_allocator(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_armnn_arena_allocator(x);
self.size_encoder = self.size_encoder.with_armnn_arena_allocator(x);
self.size_decoder = self.size_decoder.with_armnn_arena_allocator(x);
self.coord_encoder = self.coord_encoder.with_armnn_arena_allocator(x);
self.coord_decoder = self.coord_decoder.with_armnn_arena_allocator(x);
self.visual_projection = self.visual_projection.with_armnn_arena_allocator(x);
self.textual_projection = self.textual_projection.with_armnn_arena_allocator(x);
self
}
// migraphx
pub fn with_migraphx_fp16_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_migraphx_fp16(x);
self.textual = self.textual.with_migraphx_fp16(x);
self.model = self.model.with_migraphx_fp16(x);
self.encoder = self.encoder.with_migraphx_fp16(x);
self.decoder = self.decoder.with_migraphx_fp16(x);
self.visual_encoder = self.visual_encoder.with_migraphx_fp16(x);
self.textual_encoder = self.textual_encoder.with_migraphx_fp16(x);
self.visual_decoder = self.visual_decoder.with_migraphx_fp16(x);
self.textual_decoder = self.textual_decoder.with_migraphx_fp16(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_migraphx_fp16(x);
self.size_encoder = self.size_encoder.with_migraphx_fp16(x);
self.size_decoder = self.size_decoder.with_migraphx_fp16(x);
self.coord_encoder = self.coord_encoder.with_migraphx_fp16(x);
self.coord_decoder = self.coord_decoder.with_migraphx_fp16(x);
self.visual_projection = self.visual_projection.with_migraphx_fp16(x);
self.textual_projection = self.textual_projection.with_migraphx_fp16(x);
self
}
pub fn with_migraphx_exhaustive_tune_all(mut self, x: bool) -> Self {
self.visual = self.visual.with_migraphx_exhaustive_tune(x);
self.textual = self.textual.with_migraphx_exhaustive_tune(x);
self.model = self.model.with_migraphx_exhaustive_tune(x);
self.encoder = self.encoder.with_migraphx_exhaustive_tune(x);
self.decoder = self.decoder.with_migraphx_exhaustive_tune(x);
self.visual_encoder = self.visual_encoder.with_migraphx_exhaustive_tune(x);
self.textual_encoder = self.textual_encoder.with_migraphx_exhaustive_tune(x);
self.visual_decoder = self.visual_decoder.with_migraphx_exhaustive_tune(x);
self.textual_decoder = self.textual_decoder.with_migraphx_exhaustive_tune(x);
self.textual_decoder_merged = self.textual_decoder_merged.with_migraphx_exhaustive_tune(x);
self.size_encoder = self.size_encoder.with_migraphx_exhaustive_tune(x);
self.size_decoder = self.size_decoder.with_migraphx_exhaustive_tune(x);
self.coord_encoder = self.coord_encoder.with_migraphx_exhaustive_tune(x);
self.coord_decoder = self.coord_decoder.with_migraphx_exhaustive_tune(x);
self.visual_projection = self.visual_projection.with_migraphx_exhaustive_tune(x);
self.textual_projection = self.textual_projection.with_migraphx_exhaustive_tune(x);
self
}
}
impl_ort_config_methods!(Config, model);

View File

@ -3,7 +3,22 @@ pub enum Device {
Cpu(usize),
Cuda(usize),
TensorRt(usize),
CoreMl(usize),
OpenVino(&'static str),
DirectMl(usize),
Cann(usize),
Rocm(usize),
Qnn(usize),
MiGraphX(usize),
CoreMl,
Xnnpack,
RkNpu,
OneDnn,
Acl,
NnApi,
ArmNn,
Tvm,
Vitis,
Azure,
}
impl Default for Device {
@ -18,41 +33,97 @@ impl std::fmt::Display for Device {
Self::Cpu(i) => format!("CPU:{}", i),
Self::Cuda(i) => format!("CUDA:{}(NVIDIA)", i),
Self::TensorRt(i) => format!("TensorRT:{}(NVIDIA)", i),
Self::CoreMl(i) => format!("CoreML:{}(Apple)", i),
Self::Cann(i) => format!("CANN:{}(Huawei)", i),
Self::OpenVino(s) => format!("OpenVINO:{}(Intel)", s),
Self::DirectMl(i) => format!("DirectML:{}(Microsoft)", i),
Self::Qnn(i) => format!("QNN:{}(Qualcomm)", i),
Self::MiGraphX(i) => format!("MIGraphX:{}(AMD)", i),
Self::Rocm(i) => format!("ROCm:{}(AMD)", i),
Self::CoreMl => "CoreML(Apple)".to_string(),
Self::Azure => "Azure(Microsoft)".to_string(),
Self::Xnnpack => "XNNPACK".to_string(),
Self::OneDnn => "oneDNN(Intel)".to_string(),
Self::RkNpu => "RKNPU".to_string(),
Self::Acl => "ACL(Arm)".to_string(),
Self::NnApi => "NNAPI(Android)".to_string(),
Self::ArmNn => "ArmNN(Arm)".to_string(),
Self::Tvm => "TVM(Apache)".to_string(),
Self::Vitis => "VitisAI(AMD)".to_string(),
};
write!(f, "{}", x)
}
}
impl TryFrom<&str> for Device {
type Error = anyhow::Error;
impl std::str::FromStr for Device {
type Err = anyhow::Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
// device and its id
let d_id: Vec<&str> = s.trim().split(':').collect();
let (d, id) = match d_id.len() {
1 => (d_id[0].trim(), 0),
2 => (d_id[0].trim(), d_id[1].trim().parse::<usize>().unwrap_or(0)),
_ => anyhow::bail!(
"Fail to parse device string: {s}. Expect: `device:device_id` or `device`. e.g. `cuda:0` or `cuda`"
),
};
// TODO: device-id checking
match d.to_lowercase().as_str() {
"cpu" => Ok(Self::Cpu(id)),
"cuda" => Ok(Self::Cuda(id)),
"trt" | "tensorrt" => Ok(Self::TensorRt(id)),
"coreml" | "mps" => Ok(Self::CoreMl(id)),
fn from_str(s: &str) -> Result<Self, Self::Err> {
#[inline]
fn parse_device_id(id_str: Option<&str>) -> usize {
id_str
.map(|s| s.trim().parse::<usize>().unwrap_or(0))
.unwrap_or(0)
}
// Use split_once for better performance - no Vec allocation
let (device_type, id_part) = s
.trim()
.split_once(':')
.map_or_else(|| (s.trim(), None), |(device, id)| (device, Some(id)));
match device_type.to_lowercase().as_str() {
"cpu" => Ok(Self::Cpu(parse_device_id(id_part))),
"cuda" => Ok(Self::Cuda(parse_device_id(id_part))),
"trt" | "tensorrt" => Ok(Self::TensorRt(parse_device_id(id_part))),
"coreml" | "mps" => Ok(Self::CoreMl),
"openvino" => {
// For OpenVino, use the user input directly after first colon (trimmed)
let device_spec = id_part.map(|s| s.trim()).unwrap_or("CPU"); // Default to CPU if no specification provided
Ok(Self::OpenVino(Box::leak(
device_spec.to_string().into_boxed_str(),
)))
}
"directml" => Ok(Self::DirectMl(parse_device_id(id_part))),
"xnnpack" => Ok(Self::Xnnpack),
"cann" => Ok(Self::Cann(parse_device_id(id_part))),
"rknpu" => Ok(Self::RkNpu),
"onednn" => Ok(Self::OneDnn),
"acl" => Ok(Self::Acl),
"rocm" => Ok(Self::Rocm(parse_device_id(id_part))),
"nnapi" => Ok(Self::NnApi),
"armnn" => Ok(Self::ArmNn),
"tvm" => Ok(Self::Tvm),
"qnn" => Ok(Self::Qnn(parse_device_id(id_part))),
"migraphx" => Ok(Self::MiGraphX(parse_device_id(id_part))),
"vitisai" => Ok(Self::Vitis),
"azure" => Ok(Self::Azure),
_ => anyhow::bail!("Unsupported device str: {s:?}."),
}
}
}
impl Device {
pub fn id(&self) -> usize {
/// Returns the numeric id carried by the device variant, or `None` for
/// variants that do not hold one.
pub fn id(&self) -> Option<usize> {
match self {
Self::Cpu(i) | Self::Cuda(i) | Self::TensorRt(i) | Self::CoreMl(i) => *i,
// Variants that wrap a `usize` id.
Self::Cpu(i)
| Self::Cuda(i)
| Self::TensorRt(i)
| Self::Cann(i)
| Self::Qnn(i)
| Self::Rocm(i)
| Self::MiGraphX(i)
| Self::DirectMl(i) => Some(*i),
// Variants without a numeric id; `OpenVino` carries a payload that is ignored here.
Self::OpenVino(_)
| Self::Xnnpack
| Self::CoreMl
| Self::RkNpu
| Self::OneDnn
| Self::NnApi
| Self::Azure
| Self::Vitis
| Self::ArmNn
| Self::Tvm
| Self::Acl => None,
}
}
}

View File

@ -29,10 +29,10 @@ pub enum DType {
Complex128,
}
impl TryFrom<&str> for DType {
type Error = anyhow::Error;
impl std::str::FromStr for DType {
type Err = anyhow::Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"auto" | "dyn" => Ok(Self::Auto),
"u4" | "uint4" => Ok(Self::Uint4),

View File

@ -9,10 +9,42 @@ pub struct ORTConfig {
pub device: Device,
pub iiixs: Vec<Iiix>,
pub num_dry_run: usize,
pub trt_fp16: bool,
pub graph_opt_level: Option<u8>,
pub spec: String, // TODO: move out
pub dtype: DType, // For dynamically loading the model
// global
pub graph_opt_level: Option<u8>,
pub num_intra_threads: Option<usize>,
pub num_inter_threads: Option<usize>,
// cpu
pub cpu_arena_allocator: bool,
// openvino
pub openvino_dynamic_shapes: bool,
pub openvino_opencl_throttling: bool,
pub openvino_qdq_optimizer: bool,
pub openvino_num_threads: Option<usize>,
// onednn
pub onednn_arena_allocator: bool,
// tensorrt
pub tensorrt_fp16: bool,
pub tensorrt_engine_cache: bool,
pub tensorrt_timing_cache: bool,
// coreml
pub coreml_static_input_shapes: bool,
pub coreml_subgraph_running: bool,
// cann
pub cann_graph_inference: bool,
pub cann_dump_graphs: bool,
pub cann_dump_om_model: bool,
// nnapi
pub nnapi_cpu_only: bool,
pub nnapi_disable_cpu: bool,
pub nnapi_fp16: bool,
pub nnapi_nchw: bool,
// armnn
pub armnn_arena_allocator: bool,
// migraphx
pub migraphx_fp16: bool,
pub migraphx_exhaustive_tune: bool,
}
impl Default for ORTConfig {
@ -21,11 +53,33 @@ impl Default for ORTConfig {
file: Default::default(),
device: Default::default(),
iiixs: Default::default(),
graph_opt_level: Default::default(),
spec: Default::default(),
dtype: Default::default(),
num_dry_run: 3,
trt_fp16: true,
graph_opt_level: Default::default(),
num_intra_threads: None,
num_inter_threads: None,
cpu_arena_allocator: true,
openvino_dynamic_shapes: true,
openvino_opencl_throttling: true,
openvino_qdq_optimizer: true,
openvino_num_threads: None,
coreml_static_input_shapes: false,
coreml_subgraph_running: true,
tensorrt_fp16: true,
tensorrt_engine_cache: true,
tensorrt_timing_cache: false,
cann_graph_inference: true,
cann_dump_graphs: false,
cann_dump_om_model: false,
onednn_arena_allocator: true,
nnapi_cpu_only: false,
nnapi_disable_cpu: false,
nnapi_fp16: true,
nnapi_nchw: false,
armnn_arena_allocator: true,
migraphx_fp16: true,
migraphx_exhaustive_tune: false,
}
}
}
@ -122,10 +176,6 @@ macro_rules! impl_ort_config_methods {
self.$field = self.$field.with_device(device);
self
}
/// Forwards `x` to `with_trt_fp16` on `self.$field`, stores the result back in
/// `self.$field`, and returns `self` so calls can be chained.
pub fn [<with_ $field _trt_fp16>](mut self, x: bool) -> Self {
self.$field = self.$field.with_trt_fp16(x);
self
}
pub fn [<with_ $field _num_dry_run>](mut self, x: usize) -> Self {
self.$field = self.$field.with_num_dry_run(x);
self
@ -134,6 +184,113 @@ macro_rules! impl_ort_config_methods {
self.$field = self.$field.with_ixx(i, ii, x);
self
}
// global
/// Forwards the `u8` level `x` to `with_graph_opt_level` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _graph_opt_level>](mut self, x: u8) -> Self {
self.$field = self.$field.with_graph_opt_level(x);
self
}
/// Forwards the count `x` to `with_num_intra_threads` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _num_intra_threads>](mut self, x: usize) -> Self {
self.$field = self.$field.with_num_intra_threads(x);
self
}
/// Forwards the count `x` to `with_num_inter_threads` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _num_inter_threads>](mut self, x: usize) -> Self {
self.$field = self.$field.with_num_inter_threads(x);
self
}
// cpu
/// Forwards the flag `x` to `with_cpu_arena_allocator` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _cpu_arena_allocator>](mut self, x: bool) -> Self {
self.$field = self.$field.with_cpu_arena_allocator(x);
self
}
// openvino
/// Forwards the flag `x` to `with_openvino_dynamic_shapes` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _openvino_dynamic_shapes>](mut self, x: bool) -> Self {
self.$field = self.$field.with_openvino_dynamic_shapes(x);
self
}
/// Forwards the flag `x` to `with_openvino_opencl_throttling` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _openvino_opencl_throttling>](mut self, x: bool) -> Self {
self.$field = self.$field.with_openvino_opencl_throttling(x);
self
}
/// Forwards the flag `x` to `with_openvino_qdq_optimizer` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _openvino_qdq_optimizer>](mut self, x: bool) -> Self {
self.$field = self.$field.with_openvino_qdq_optimizer(x);
self
}
/// Forwards the count `x` to `with_openvino_num_threads` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _openvino_num_threads>](mut self, x: usize) -> Self {
self.$field = self.$field.with_openvino_num_threads(x);
self
}
// onednn
/// Forwards the flag `x` to `with_onednn_arena_allocator` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _onednn_arena_allocator>](mut self, x: bool) -> Self {
self.$field = self.$field.with_onednn_arena_allocator(x);
self
}
// tensorrt
/// Forwards the flag `x` to `with_tensorrt_fp16` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _tensorrt_fp16>](mut self, x: bool) -> Self {
self.$field = self.$field.with_tensorrt_fp16(x);
self
}
/// Forwards the flag `x` to `with_tensorrt_engine_cache` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _tensorrt_engine_cache>](mut self, x: bool) -> Self {
self.$field = self.$field.with_tensorrt_engine_cache(x);
self
}
/// Forwards the flag `x` to `with_tensorrt_timing_cache` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _tensorrt_timing_cache>](mut self, x: bool) -> Self {
self.$field = self.$field.with_tensorrt_timing_cache(x);
self
}
// coreml
/// Forwards the flag `x` to `with_coreml_static_input_shapes` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _coreml_static_input_shapes>](mut self, x: bool) -> Self {
self.$field = self.$field.with_coreml_static_input_shapes(x);
self
}
/// Forwards the flag `x` to `with_coreml_subgraph_running` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _coreml_subgraph_running>](mut self, x: bool) -> Self {
self.$field = self.$field.with_coreml_subgraph_running(x);
self
}
// cann
/// Forwards the flag `x` to `with_cann_graph_inference` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _cann_graph_inference>](mut self, x: bool) -> Self {
self.$field = self.$field.with_cann_graph_inference(x);
self
}
/// Forwards the flag `x` to `with_cann_dump_graphs` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _cann_dump_graphs>](mut self, x: bool) -> Self {
self.$field = self.$field.with_cann_dump_graphs(x);
self
}
/// Forwards the flag `x` to `with_cann_dump_om_model` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _cann_dump_om_model>](mut self, x: bool) -> Self {
self.$field = self.$field.with_cann_dump_om_model(x);
self
}
// nnapi
/// Forwards the flag `x` to `with_nnapi_cpu_only` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _nnapi_cpu_only>](mut self, x: bool) -> Self {
self.$field = self.$field.with_nnapi_cpu_only(x);
self
}
/// Forwards the flag `x` to `with_nnapi_disable_cpu` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _nnapi_disable_cpu>](mut self, x: bool) -> Self {
self.$field = self.$field.with_nnapi_disable_cpu(x);
self
}
/// Forwards the flag `x` to `with_nnapi_fp16` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _nnapi_fp16>](mut self, x: bool) -> Self {
self.$field = self.$field.with_nnapi_fp16(x);
self
}
/// Forwards the flag `x` to `with_nnapi_nchw` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _nnapi_nchw>](mut self, x: bool) -> Self {
self.$field = self.$field.with_nnapi_nchw(x);
self
}
// armnn
/// Forwards the flag `x` to `with_armnn_arena_allocator` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _armnn_arena_allocator>](mut self, x: bool) -> Self {
self.$field = self.$field.with_armnn_arena_allocator(x);
self
}
// migraphx
/// Forwards the flag `x` to `with_migraphx_fp16` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _migraphx_fp16>](mut self, x: bool) -> Self {
self.$field = self.$field.with_migraphx_fp16(x);
self
}
/// Forwards the flag `x` to `with_migraphx_exhaustive_tune` on `self.$field`,
/// stores the result back in `self.$field`, and returns `self` for chaining.
pub fn [<with_ $field _migraphx_exhaustive_tune>](mut self, x: bool) -> Self {
self.$field = self.$field.with_migraphx_exhaustive_tune(x);
self
}
}
}
};

View File

@ -1,3 +1,5 @@
use std::str::FromStr;
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum Scale {
N,
@ -64,10 +66,10 @@ impl TryFrom<char> for Scale {
}
}
impl TryFrom<&str> for Scale {
type Error = anyhow::Error;
impl FromStr for Scale {
type Err = anyhow::Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"n" | "nano" => Ok(Self::N),
"t" | "tiny" => Ok(Self::T),

View File

@ -1,3 +1,5 @@
use std::str::FromStr;
#[derive(Debug, Clone, Ord, Eq, PartialOrd, PartialEq)]
pub enum Task {
/// Image classification task.
@ -164,10 +166,10 @@ impl std::fmt::Display for Task {
}
}
impl TryFrom<&str> for Task {
type Error = anyhow::Error;
impl FromStr for Task {
type Err = anyhow::Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
fn from_str(s: &str) -> Result<Self, Self::Err> {
// TODO
match s.to_lowercase().as_str() {
"cls" | "classify" | "classification" => Ok(Self::ImageClassification),

View File

@ -61,20 +61,24 @@ impl From<Color> for [u8; 3] {
}
}
impl TryFrom<&str> for Color {
type Error = &'static str;
/// Parses a `Color` from a hexadecimal string of 6 or 8 digits, with any
/// leading `#` characters removed first.
impl std::str::FromStr for Color {
type Err = anyhow::Error;
fn try_from(x: &str) -> Result<Self, Self::Error> {
/// Returns an error when the length (after stripping `#`) is neither 6 nor 8,
/// or when the text is not valid base-16.
fn from_str(x: &str) -> Result<Self, Self::Err> {
// `trim_start_matches` removes every leading `#`, not just the first one.
let hex = x.trim_start_matches('#');
let hex = match hex.len() {
// Six digits: `ff` is appended as the lowest byte (presumably the alpha channel — confirm).
6 => format!("{}ff", hex),
8 => hex.to_string(),
_ => return Err("Failed to convert `Color` from str: invalid length"),
_ => {
return Err(anyhow::anyhow!(
"Failed to convert `Color` from str: invalid length"
))
}
};
// NOTE(review): `u32::from_str_radix` also accepts a leading `+`, so an input
// such as `+12345` passes the length check and parses — confirm whether that
// should be rejected.
u32::from_str_radix(&hex, 16)
.map(Self)
.map_err(|_| "Failed to convert `Color` from str: invalid hex")
.map_err(|_| anyhow::anyhow!("Failed to convert `Color` from str: invalid hex"))
}
}
@ -151,17 +155,8 @@ impl Color {
xs.iter().copied().map(Into::into).collect()
}
pub fn try_create_palette<A: TryInto<Self> + Copy>(xs: &[A]) -> Result<Vec<Self>>
where
<A as TryInto<Self>>::Error: std::fmt::Debug,
{
xs.iter()
.copied()
.map(|x| {
x.try_into()
.map_err(|e| anyhow::anyhow!("Failed to convert: {:?}", e))
})
.collect()
/// Parses each string in `xs` with `str::parse` and collects the colors in
/// input order; collecting into `Result` stops at the first parse error and
/// returns it.
pub fn try_create_palette(xs: &[&str]) -> Result<Vec<Self>> {
xs.iter().map(|x| x.parse()).collect()
}
pub fn palette_rand(n: usize) -> Vec<Self> {

View File

@ -14,20 +14,22 @@ pub enum ColorMap256 {
SmoothCoolWarm,
}
impl From<&str> for ColorMap256 {
fn from(s: &str) -> Self {
/// Looks up a `ColorMap256` variant by name, ignoring ASCII/Unicode case.
impl std::str::FromStr for ColorMap256 {
type Err = anyhow::Error;
/// Returns an error naming the original input when it matches none of the
/// names below.
fn from_str(s: &str) -> Result<Self, Self::Err> {
// The input is lowercased before matching; it is not trimmed.
match s.to_lowercase().as_str() {
"turbo" => Self::Turbo,
"inferno" => Self::Inferno,
"plasma" => Self::Plasma,
"viridis" => Self::Viridis,
"magma" => Self::Magma,
"bentcoolwarm" => Self::BentCoolWarm,
"blackbody" => Self::BlackBody,
"extendedkindlmann" => Self::ExtendedKindLmann,
"kindlmann" => Self::KindLmann,
"smoothcoolwarm" => Self::SmoothCoolWarm,
s => unimplemented!("{} is not supported for now!", s),
"turbo" => Ok(Self::Turbo),
"inferno" => Ok(Self::Inferno),
"plasma" => Ok(Self::Plasma),
"viridis" => Ok(Self::Viridis),
"magma" => Ok(Self::Magma),
"bentcoolwarm" => Ok(Self::BentCoolWarm),
"blackbody" => Ok(Self::BlackBody),
"extendedkindlmann" => Ok(Self::ExtendedKindLmann),
"kindlmann" => Ok(Self::KindLmann),
"smoothcoolwarm" => Ok(Self::SmoothCoolWarm),
// Unknown names are reported as an error value rather than a panic.
_ => Err(anyhow::anyhow!("Unsupported colormap: {}", s)),
}
}
}