mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Add some eps (#108)
This commit is contained in:
21
Cargo.toml
21
Cargo.toml
@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "usls"
|
||||
edition = "2021"
|
||||
version = "0.1.0-beta.3"
|
||||
version = "0.1.0-beta.4"
|
||||
rust-version = "1.82"
|
||||
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
|
||||
repository = "https://github.com/jamjamjon/usls"
|
||||
@ -45,6 +45,7 @@ ort = { version = "=2.0.0-rc.10", default-features = false, optional = true, fea
|
||||
] }
|
||||
tokenizers = { version = "0.21.1" }
|
||||
paste = "1.0.15"
|
||||
base64ct = "=1.7.3"
|
||||
|
||||
[build-dependencies]
|
||||
prost-build = "0.13.5"
|
||||
@ -53,11 +54,27 @@ prost-build = "0.13.5"
|
||||
argh = "0.1.13"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "chrono"] }
|
||||
|
||||
|
||||
[features]
|
||||
default = [ "ort-download-binaries" ]
|
||||
video = [ "dep:video-rs" ]
|
||||
ort-download-binaries = [ "ort", "ort/download-binaries" ]
|
||||
ort-load-dynamic = [ "ort", "ort/load-dynamic" ]
|
||||
cuda = [ "ort/cuda" ]
|
||||
trt = [ "ort/tensorrt" ]
|
||||
tensorrt = [ "ort/tensorrt" ]
|
||||
coreml = [ "ort/coreml" ]
|
||||
openvino = [ "ort/openvino" ]
|
||||
onednn = [ "ort/onednn" ]
|
||||
directml = [ "ort/directml" ]
|
||||
xnnpack = [ "ort/xnnpack" ]
|
||||
cann = [ "ort/cann" ]
|
||||
rknpu = [ "ort/rknpu" ]
|
||||
acl = [ "ort/acl" ]
|
||||
rocm = [ "ort/rocm" ]
|
||||
nnapi = [ "ort/nnapi" ]
|
||||
armnn = [ "ort/armnn" ]
|
||||
tvm = [ "ort/tvm" ]
|
||||
qnn = [ "ort/qnn" ]
|
||||
migraphx = [ "ort/migraphx" ]
|
||||
vitis = [ "ort/vitis" ]
|
||||
azure = [ "ort/azure" ]
|
||||
|
@ -21,8 +21,8 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// build model
|
||||
let config = Config::ben2_base()
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = RMBG::new(config)?;
|
||||
|
||||
|
@ -21,7 +21,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// build model
|
||||
let config = Config::blip_v1_base_caption()
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = Blip::new(config)?;
|
||||
|
||||
|
@ -46,8 +46,8 @@ fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
|
||||
let config = config
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = ImageClassifier::try_from(config)?;
|
||||
|
||||
|
@ -29,8 +29,8 @@ fn main() -> Result<()> {
|
||||
// clip_vit_b32()
|
||||
// jina_clip_v1()
|
||||
// jina_clip_v2()
|
||||
.with_dtype_all(args.dtype.as_str().try_into()?)
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_dtype_all(args.dtype.parse()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = Clip::new(config)?;
|
||||
|
||||
|
@ -47,9 +47,9 @@ fn main() -> Result<()> {
|
||||
// build model
|
||||
let config = match &args.model {
|
||||
Some(m) => Config::db().with_model_file(m),
|
||||
None => Config::ppocr_det_v5_mobile().with_model_dtype(args.dtype.as_str().try_into()?),
|
||||
None => Config::ppocr_det_v5_mobile().with_model_dtype(args.dtype.parse()?),
|
||||
}
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = DB::new(config)?;
|
||||
|
||||
|
@ -18,7 +18,7 @@ fn main() -> Result<()> {
|
||||
|
||||
// annotate
|
||||
let annotator =
|
||||
Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".into()));
|
||||
Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".parse()?));
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
annotator.annotate(x, y)?.save(format!(
|
||||
"{}.jpg",
|
||||
|
@ -24,8 +24,8 @@ fn main() -> Result<()> {
|
||||
|
||||
// model
|
||||
let config = Config::depth_pro()
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?;
|
||||
|
||||
let mut model = DepthPro::new(config)?;
|
||||
@ -38,7 +38,7 @@ fn main() -> Result<()> {
|
||||
|
||||
// annotate
|
||||
let annotator =
|
||||
Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".into()));
|
||||
Annotator::default().with_mask_style(Style::mask().with_colormap256("turbo".parse()?));
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
annotator.annotate(x, y)?.save(format!(
|
||||
"{}.jpg",
|
||||
|
@ -19,7 +19,7 @@ fn main() -> Result<()> {
|
||||
|
||||
// build model
|
||||
let config = Config::doclayout_yolo_docstructbench()
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = YOLO::new(config)?;
|
||||
|
||||
|
@ -26,7 +26,7 @@ fn main() -> Result<()> {
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// build model
|
||||
let config = match args.scale.as_str().try_into()? {
|
||||
let config = match args.scale.parse()? {
|
||||
Scale::T => Config::fast_tiny(),
|
||||
Scale::S => Config::fast_small(),
|
||||
Scale::B => Config::fast_base(),
|
||||
@ -34,8 +34,8 @@ fn main() -> Result<()> {
|
||||
};
|
||||
let mut model = DB::new(
|
||||
config
|
||||
.with_dtype_all(args.dtype.as_str().try_into()?)
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_dtype_all(args.dtype.parse()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.commit()?,
|
||||
)?;
|
||||
|
||||
|
@ -23,8 +23,8 @@ fn main() -> Result<()> {
|
||||
|
||||
// build model
|
||||
let config = Config::fastsam_s()
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = YOLO::new(config)?;
|
||||
|
||||
|
@ -26,8 +26,8 @@ fn main() -> Result<()> {
|
||||
|
||||
// build model
|
||||
let config = Config::florence2_base()
|
||||
.with_dtype_all(args.dtype.as_str().try_into()?)
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_dtype_all(args.dtype.parse()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.with_batch_size_all(xs.len())
|
||||
.commit()?;
|
||||
let mut model = Florence2::new(config)?;
|
||||
|
@ -46,8 +46,8 @@ fn main() -> Result<()> {
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
let config = Config::grounding_dino_tiny()
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
|
||||
.with_class_confs(&[0.25])
|
||||
.with_text_confs(&[0.25])
|
||||
|
@ -27,7 +27,7 @@ fn main() -> Result<()> {
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// build model
|
||||
let config = match args.scale.as_str().try_into()? {
|
||||
let config = match args.scale.parse()? {
|
||||
Scale::T => Config::linknet_r18(),
|
||||
Scale::S => Config::linknet_r34(),
|
||||
Scale::B => Config::linknet_r50(),
|
||||
@ -35,8 +35,8 @@ fn main() -> Result<()> {
|
||||
};
|
||||
let mut model = DB::new(
|
||||
config
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?,
|
||||
)?;
|
||||
|
||||
|
@ -39,13 +39,13 @@ fn main() -> Result<()> {
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// build model
|
||||
let config = match args.scale.as_str().try_into()? {
|
||||
let config = match args.scale.parse()? {
|
||||
Scale::Billion(0.5) => Config::moondream2_0_5b(),
|
||||
Scale::Billion(2.) => Config::moondream2_2b(),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
.with_dtype_all(args.dtype.as_str().try_into()?)
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_dtype_all(args.dtype.parse()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.commit()?;
|
||||
|
||||
let mut model = Moondream2::new(config)?;
|
||||
@ -54,7 +54,7 @@ fn main() -> Result<()> {
|
||||
let xs = DataLoader::try_read_n(&args.source)?;
|
||||
|
||||
// run with task
|
||||
let task: Task = args.task.as_str().try_into()?;
|
||||
let task: Task = args.task.parse()?;
|
||||
let ys = model.forward(&xs, &task)?;
|
||||
|
||||
// annotate
|
||||
|
@ -49,8 +49,8 @@ fn main() -> Result<()> {
|
||||
// config
|
||||
let config = Config::owlv2_base_ensemble()
|
||||
// owlv2_base()
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
|
||||
.commit()?;
|
||||
let mut model = OWLv2::new(config)?;
|
||||
|
@ -31,8 +31,8 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// build model
|
||||
let config = config
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = RMBG::new(config)?;
|
||||
|
||||
|
@ -28,9 +28,9 @@ fn main() -> Result<()> {
|
||||
|
||||
let args: Args = argh::from_env();
|
||||
// Build model
|
||||
let config = match args.kind.as_str().try_into()? {
|
||||
let config = match args.kind.parse()? {
|
||||
SamKind::Sam => Config::sam_v1_base(),
|
||||
SamKind::Sam2 => match args.scale.as_str().try_into()? {
|
||||
SamKind::Sam2 => match args.scale.parse()? {
|
||||
Scale::T => Config::sam2_tiny(),
|
||||
Scale::S => Config::sam2_small(),
|
||||
Scale::B => Config::sam2_base_plus(),
|
||||
@ -40,7 +40,7 @@ fn main() -> Result<()> {
|
||||
SamKind::SamHq => Config::sam_hq_tiny(),
|
||||
SamKind::EdgeSam => Config::edge_sam_3x(),
|
||||
}
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.commit()?;
|
||||
|
||||
let mut model = SAM::new(config)?;
|
||||
|
@ -25,14 +25,14 @@ fn main() -> Result<()> {
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// Build model
|
||||
let config = match args.scale.as_str().try_into()? {
|
||||
let config = match args.scale.parse()? {
|
||||
Scale::T => Config::sam2_1_tiny(),
|
||||
Scale::S => Config::sam2_1_small(),
|
||||
Scale::B => Config::sam2_1_base_plus(),
|
||||
Scale::L => Config::sam2_1_large(),
|
||||
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale),
|
||||
}
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = SAM2::new(config)?;
|
||||
|
||||
|
@ -18,7 +18,7 @@ fn main() -> Result<()> {
|
||||
let args: Args = argh::from_env();
|
||||
// build
|
||||
let config = Config::sapiens_seg_0_3b()
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = Sapiens::new(config)?;
|
||||
|
||||
|
@ -27,8 +27,8 @@ fn main() -> Result<()> {
|
||||
|
||||
// build model
|
||||
let config = Config::slanet_lcnet_v2_mobile_ch()
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.commit()?;
|
||||
let mut model = SLANet::new(config)?;
|
||||
|
||||
|
@ -29,12 +29,12 @@ fn main() -> Result<()> {
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// build model
|
||||
let config = match args.scale.as_str().try_into()? {
|
||||
let config = match args.scale.parse()? {
|
||||
Scale::Million(256.) => Config::smolvlm_256m(),
|
||||
Scale::Million(500.) => Config::smolvlm_500m(),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = SmolVLM::new(config)?;
|
||||
|
||||
|
@ -32,8 +32,8 @@ fn main() -> Result<()> {
|
||||
// ppocr_rec_v4_en()
|
||||
// repsvtr_ch()
|
||||
.with_model_ixx(0, 3, args.max_text_length.into())
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.commit()?;
|
||||
let mut model = SVTR::new(config)?;
|
||||
|
||||
|
@ -38,19 +38,19 @@ fn main() -> anyhow::Result<()> {
|
||||
])?;
|
||||
|
||||
// build model
|
||||
let config = match args.scale.as_str().try_into()? {
|
||||
Scale::S => match args.kind.as_str().try_into()? {
|
||||
let config = match args.scale.parse()? {
|
||||
Scale::S => match args.kind.parse()? {
|
||||
TrOCRKind::Printed => Config::trocr_small_printed(),
|
||||
TrOCRKind::HandWritten => Config::trocr_small_handwritten(),
|
||||
},
|
||||
Scale::B => match args.kind.as_str().try_into()? {
|
||||
Scale::B => match args.kind.parse()? {
|
||||
TrOCRKind::Printed => Config::trocr_base_printed(),
|
||||
TrOCRKind::HandWritten => Config::trocr_base_handwritten(),
|
||||
},
|
||||
x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x),
|
||||
}
|
||||
.with_device_all(args.device.as_str().try_into()?)
|
||||
.with_dtype_all(args.dtype.as_str().try_into()?)
|
||||
.with_device_all(args.device.parse()?)
|
||||
.with_dtype_all(args.dtype.parse()?)
|
||||
.commit()?;
|
||||
|
||||
let mut model = TrOCR::new(config)?;
|
||||
|
@ -23,8 +23,8 @@ fn main() -> Result<()> {
|
||||
|
||||
// build model
|
||||
let config = Config::ultralytics_rtdetr_l()
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut model = YOLO::new(config)?;
|
||||
|
||||
|
@ -27,7 +27,7 @@ fn main() -> Result<()> {
|
||||
let options_yolo = Config::yolo_detect()
|
||||
.with_scale(Scale::N)
|
||||
.with_version(8.into())
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.commit()?;
|
||||
let mut yolo = YOLO::new(options_yolo)?;
|
||||
|
||||
|
@ -132,12 +132,12 @@ fn main() -> Result<()> {
|
||||
let args: Args = argh::from_env();
|
||||
let mut config = Config::yolo()
|
||||
.with_model_file(&args.model.unwrap_or_default())
|
||||
.with_task(args.task.as_str().try_into()?)
|
||||
.with_task(args.task.parse()?)
|
||||
.with_version(args.ver.try_into()?)
|
||||
.with_scale(args.scale.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_trt_fp16(args.trt_fp16)
|
||||
.with_scale(args.scale.parse()?)
|
||||
.with_model_dtype(args.dtype.parse()?)
|
||||
.with_model_device(args.device.parse()?)
|
||||
.with_model_tensorrt_fp16(args.trt_fp16)
|
||||
.with_model_ixx(
|
||||
0,
|
||||
0,
|
||||
|
@ -28,8 +28,8 @@ fn main() -> Result<()> {
|
||||
// yoloe_11s_seg_pf()
|
||||
// yoloe_11m_seg_pf()
|
||||
// yoloe_11l_seg_pf()
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_model_dtype(args.dtype.as_str().parse()?)
|
||||
.with_model_device(args.device.as_str().parse()?)
|
||||
.commit()?;
|
||||
let mut model = YOLO::new(config)?;
|
||||
|
||||
|
@ -66,7 +66,6 @@ pub struct Engine {
|
||||
pub file: String,
|
||||
pub spec: String,
|
||||
pub device: Device,
|
||||
pub trt_fp16: bool,
|
||||
#[args(inc)]
|
||||
pub iiixs: Vec<Iiix>,
|
||||
#[args(aka = "parameters")]
|
||||
@ -77,7 +76,50 @@ pub struct Engine {
|
||||
pub onnx: Option<OnnxIo>,
|
||||
pub ts: Ts,
|
||||
pub num_dry_run: usize,
|
||||
|
||||
// global
|
||||
pub graph_opt_level: Option<u8>,
|
||||
pub num_intra_threads: Option<usize>,
|
||||
pub num_inter_threads: Option<usize>,
|
||||
|
||||
// cpu
|
||||
pub cpu_arena_allocator: bool,
|
||||
|
||||
// tensorrt
|
||||
pub tensorrt_fp16: bool,
|
||||
pub tensorrt_engine_cache: bool,
|
||||
pub tensorrt_timing_cache: bool,
|
||||
|
||||
// openvino
|
||||
pub openvino_dynamic_shapes: bool,
|
||||
pub openvino_opencl_throttling: bool,
|
||||
pub openvino_qdq_optimizer: bool,
|
||||
pub openvino_num_threads: Option<usize>,
|
||||
|
||||
// onednn
|
||||
pub onednn_arena_allocator: bool,
|
||||
|
||||
// coreml
|
||||
pub coreml_static_input_shapes: bool,
|
||||
pub coreml_subgraph_running: bool,
|
||||
|
||||
// cann
|
||||
pub cann_graph_inference: bool,
|
||||
pub cann_dump_graphs: bool,
|
||||
pub cann_dump_om_model: bool,
|
||||
|
||||
// nnapi
|
||||
pub nnapi_cpu_only: bool,
|
||||
pub nnapi_disable_cpu: bool,
|
||||
pub nnapi_fp16: bool,
|
||||
pub nnapi_nchw: bool,
|
||||
|
||||
// armnn
|
||||
pub armnn_arena_allocator: bool,
|
||||
|
||||
// migraphx
|
||||
pub migraphx_fp16: bool,
|
||||
pub migraphx_exhaustive_tune: bool,
|
||||
}
|
||||
|
||||
impl Default for Engine {
|
||||
@ -85,7 +127,6 @@ impl Default for Engine {
|
||||
Self {
|
||||
file: Default::default(),
|
||||
device: Device::Cpu(0),
|
||||
trt_fp16: false,
|
||||
spec: Default::default(),
|
||||
iiixs: Default::default(),
|
||||
num_dry_run: 3,
|
||||
@ -94,7 +135,40 @@ impl Default for Engine {
|
||||
inputs_minoptmax: vec![],
|
||||
onnx: None,
|
||||
ts: Ts::default(),
|
||||
// global
|
||||
graph_opt_level: None,
|
||||
num_intra_threads: None,
|
||||
num_inter_threads: None,
|
||||
// cpu
|
||||
cpu_arena_allocator: true,
|
||||
// openvino
|
||||
openvino_dynamic_shapes: true,
|
||||
openvino_opencl_throttling: true,
|
||||
openvino_qdq_optimizer: true,
|
||||
openvino_num_threads: None,
|
||||
// onednn
|
||||
onednn_arena_allocator: true,
|
||||
// coreml
|
||||
coreml_static_input_shapes: false,
|
||||
coreml_subgraph_running: true,
|
||||
// tensorrt
|
||||
tensorrt_fp16: true,
|
||||
tensorrt_engine_cache: true,
|
||||
tensorrt_timing_cache: false,
|
||||
// cann
|
||||
cann_graph_inference: true,
|
||||
cann_dump_graphs: false,
|
||||
cann_dump_om_model: false,
|
||||
// nnapi
|
||||
nnapi_cpu_only: false,
|
||||
nnapi_disable_cpu: false,
|
||||
nnapi_fp16: true,
|
||||
nnapi_nchw: false,
|
||||
// armnn
|
||||
armnn_arena_allocator: true,
|
||||
// migraphx
|
||||
migraphx_fp16: true,
|
||||
migraphx_exhaustive_tune: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -106,9 +180,40 @@ impl Engine {
|
||||
spec: config.spec.clone(),
|
||||
iiixs: config.iiixs.clone(),
|
||||
device: config.device,
|
||||
trt_fp16: config.trt_fp16,
|
||||
num_dry_run: config.num_dry_run,
|
||||
// global
|
||||
graph_opt_level: config.graph_opt_level,
|
||||
num_intra_threads: config.num_intra_threads,
|
||||
num_inter_threads: config.num_inter_threads,
|
||||
// cpu
|
||||
cpu_arena_allocator: config.cpu_arena_allocator,
|
||||
// openvino
|
||||
openvino_dynamic_shapes: config.openvino_dynamic_shapes,
|
||||
openvino_opencl_throttling: config.openvino_opencl_throttling,
|
||||
openvino_qdq_optimizer: config.openvino_qdq_optimizer,
|
||||
openvino_num_threads: config.openvino_num_threads,
|
||||
// coreml
|
||||
coreml_static_input_shapes: config.coreml_static_input_shapes,
|
||||
coreml_subgraph_running: config.coreml_subgraph_running,
|
||||
// tensorrt
|
||||
tensorrt_fp16: config.tensorrt_fp16,
|
||||
tensorrt_engine_cache: config.tensorrt_engine_cache,
|
||||
tensorrt_timing_cache: config.tensorrt_timing_cache,
|
||||
// cann
|
||||
cann_graph_inference: config.cann_graph_inference,
|
||||
cann_dump_graphs: config.cann_dump_graphs,
|
||||
cann_dump_om_model: config.cann_dump_om_model,
|
||||
// nnapi
|
||||
nnapi_cpu_only: config.nnapi_cpu_only,
|
||||
nnapi_disable_cpu: config.nnapi_disable_cpu,
|
||||
nnapi_fp16: config.nnapi_fp16,
|
||||
nnapi_nchw: config.nnapi_nchw,
|
||||
// armnn
|
||||
armnn_arena_allocator: config.armnn_arena_allocator,
|
||||
// migraphx
|
||||
migraphx_fp16: config.migraphx_fp16,
|
||||
migraphx_exhaustive_tune: config.migraphx_exhaustive_tune,
|
||||
|
||||
..Default::default()
|
||||
}
|
||||
.build()
|
||||
@ -338,17 +443,20 @@ impl Engine {
|
||||
let compile_help = "Please compile ONNXRuntime with #EP";
|
||||
let feature_help = "#EP EP requires the features: `#FEATURE`. \
|
||||
\nConsider enabling them by passing, e.g., `--features #FEATURE`";
|
||||
let n_threads_available = std::thread::available_parallelism()
|
||||
.map(|n| n.get())
|
||||
.unwrap_or(1);
|
||||
|
||||
match self.device {
|
||||
Device::TensorRt(id) => {
|
||||
#[cfg(not(feature = "trt"))]
|
||||
#[cfg(not(feature = "tensorrt"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "TensorRT")
|
||||
.replace("#FEATURE", "trt"));
|
||||
.replace("#FEATURE", "tensorrt"));
|
||||
}
|
||||
|
||||
#[cfg(feature = "trt")]
|
||||
#[cfg(feature = "tensorrt")]
|
||||
{
|
||||
// generate shapes
|
||||
let mut spec_min = String::new();
|
||||
@ -379,13 +487,16 @@ impl Engine {
|
||||
spec_max += &s_max;
|
||||
}
|
||||
|
||||
let p = crate::Dir::Cache.crate_dir_default_with_subs(&["trt-cache"])?;
|
||||
let ep = ort::execution_providers::TensorRTExecutionProvider::default()
|
||||
.with_device_id(id as i32)
|
||||
.with_fp16(self.trt_fp16)
|
||||
.with_engine_cache(true)
|
||||
.with_engine_cache_path(p.to_str().unwrap())
|
||||
.with_timing_cache(false)
|
||||
.with_fp16(self.tensorrt_fp16)
|
||||
.with_engine_cache(self.tensorrt_engine_cache)
|
||||
.with_timing_cache(self.tensorrt_timing_cache)
|
||||
.with_engine_cache_path(
|
||||
crate::Dir::Cache
|
||||
.crate_dir_default_with_subs(&["caches", "tensorrt"])?
|
||||
.display(),
|
||||
)
|
||||
.with_profile_min_shapes(spec_min)
|
||||
.with_profile_opt_shapes(spec_opt)
|
||||
.with_profile_max_shapes(spec_max);
|
||||
@ -427,7 +538,7 @@ impl Engine {
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::CoreMl(id) => {
|
||||
Device::CoreMl => {
|
||||
#[cfg(not(feature = "coreml"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
@ -439,12 +550,12 @@ impl Engine {
|
||||
let ep = ort::execution_providers::CoreMLExecutionProvider::default()
|
||||
.with_model_cache_dir(
|
||||
crate::Dir::Cache
|
||||
.crate_dir_default_with_subs(&["coreml-cache"])?
|
||||
.crate_dir_default_with_subs(&["caches", "coreml"])?
|
||||
.display(),
|
||||
)
|
||||
.with_static_input_shapes(self.coreml_static_input_shapes)
|
||||
.with_subgraphs(self.coreml_subgraph_running)
|
||||
.with_compute_units(ort::execution_providers::coreml::CoreMLComputeUnits::All)
|
||||
.with_static_input_shapes(false)
|
||||
.with_subgraphs(true)
|
||||
.with_model_format(ort::execution_providers::coreml::CoreMLModelFormat::MLProgram)
|
||||
.with_specialization_strategy(
|
||||
ort::execution_providers::coreml::CoreMLSpecializationStrategy::FastPrediction,
|
||||
@ -459,9 +570,345 @@ impl Engine {
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::OpenVino(dt) => {
|
||||
#[cfg(not(feature = "openvino"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "OpenVINO")
|
||||
.replace("#FEATURE", "openvino"));
|
||||
}
|
||||
|
||||
#[cfg(feature = "openvino")]
|
||||
{
|
||||
let ep = ort::execution_providers::OpenVINOExecutionProvider::default()
|
||||
.with_device_type(dt)
|
||||
.with_num_threads(self.openvino_num_threads.unwrap_or(n_threads_available))
|
||||
.with_dynamic_shapes(self.openvino_dynamic_shapes)
|
||||
.with_opencl_throttling(self.openvino_opencl_throttling)
|
||||
.with_qdq_optimizer(self.openvino_qdq_optimizer)
|
||||
.with_cache_dir(
|
||||
crate::Dir::Cache
|
||||
.crate_dir_default_with_subs(&["caches", "openvino"])?
|
||||
.display()
|
||||
.to_string(),
|
||||
);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register OpenVINO: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "OpenVINO")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::DirectMl(id) => {
|
||||
#[cfg(not(feature = "directml"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "DirectML")
|
||||
.replace("#FEATURE", "directml"));
|
||||
}
|
||||
#[cfg(feature = "directml")]
|
||||
{
|
||||
let ep = ort::execution_providers::DirectMLExecutionProvider::default()
|
||||
.with_device_id(id as i32);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register DirectML: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "DirectML")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::Xnnpack => {
|
||||
#[cfg(not(feature = "xnnpack"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "XNNPack")
|
||||
.replace("#FEATURE", "xnnpack"));
|
||||
}
|
||||
#[cfg(feature = "xnnpack")]
|
||||
{
|
||||
let ep = ort::execution_providers::XNNPACKExecutionProvider::default();
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register XNNPack: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "XNNPack")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::Cann(id) => {
|
||||
#[cfg(not(feature = "cann"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "CANN")
|
||||
.replace("#FEATURE", "cann"));
|
||||
}
|
||||
#[cfg(feature = "cann")]
|
||||
{
|
||||
let ep = ort::execution_providers::CANNExecutionProvider::default()
|
||||
.with_device_id(id as i32)
|
||||
.with_cann_graph(self.cann_graph_inference)
|
||||
.with_dump_graphs(self.cann_dump_graphs)
|
||||
.with_dump_om_model(self.cann_dump_om_model);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register CANN: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "CANN")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::RkNpu => {
|
||||
#[cfg(not(feature = "rknpu"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "RKNPU")
|
||||
.replace("#FEATURE", "rknpu"));
|
||||
}
|
||||
#[cfg(feature = "rknpu")]
|
||||
{
|
||||
let ep = ort::execution_providers::RKNPUExecutionProvider::default();
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register RKNPU: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "RKNPU")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::OneDnn => {
|
||||
#[cfg(not(feature = "onednn"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "oneDNN")
|
||||
.replace("#FEATURE", "onednn"));
|
||||
}
|
||||
#[cfg(feature = "onednn")]
|
||||
{
|
||||
let ep = ort::execution_providers::OneDNNExecutionProvider::default()
|
||||
.with_arena_allocator(self.onednn_arena_allocator);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register oneDNN: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "oneDNN")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::Acl => {
|
||||
#[cfg(not(feature = "acl"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "ArmACL")
|
||||
.replace("#FEATURE", "acl"));
|
||||
}
|
||||
#[cfg(feature = "acl")]
|
||||
{
|
||||
let ep = ort::execution_providers::ACLExecutionProvider::default()
|
||||
.with_fast_math(true);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register ArmACL: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "ArmACL")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::Rocm(id) => {
|
||||
#[cfg(not(feature = "rocm"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "ROCm")
|
||||
.replace("#FEATURE", "rocm"));
|
||||
}
|
||||
#[cfg(feature = "rocm")]
|
||||
{
|
||||
let ep = ort::execution_providers::ROCmExecutionProvider::default()
|
||||
.with_device_id(id as _);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register ROCm: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "ROCm")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::NnApi => {
|
||||
#[cfg(not(feature = "nnapi"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "NNAPI")
|
||||
.replace("#FEATURE", "nnapi"));
|
||||
}
|
||||
#[cfg(feature = "nnapi")]
|
||||
{
|
||||
let ep = ort::execution_providers::NNAPIExecutionProvider::default()
|
||||
.with_fp16(self.nnapi_fp16)
|
||||
.with_nchw(self.nnapi_nchw)
|
||||
.with_cpu_only(self.nnapi_cpu_only)
|
||||
.with_disable_cpu(self.nnapi_disable_cpu);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register NNAPI: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "NNAPI")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::ArmNn => {
|
||||
#[cfg(not(feature = "armnn"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "ArmNN")
|
||||
.replace("#FEATURE", "armnn"));
|
||||
}
|
||||
#[cfg(feature = "armnn")]
|
||||
{
|
||||
let ep = ort::execution_providers::ArmNNExecutionProvider::default()
|
||||
.with_arena_allocator(self.armnn_arena_allocator);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register ArmNN: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "ArmNN")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::Tvm => {
|
||||
#[cfg(not(feature = "tvm"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "TVM")
|
||||
.replace("#FEATURE", "tvm"));
|
||||
}
|
||||
#[cfg(feature = "tvm")]
|
||||
{
|
||||
let ep = ort::execution_providers::TVMExecutionProvider::default();
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register TVM: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "TVM")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::Qnn(id) => {
|
||||
#[cfg(not(feature = "qnn"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "QNN")
|
||||
.replace("#FEATURE", "qnn"));
|
||||
}
|
||||
#[cfg(feature = "qnn")]
|
||||
{
|
||||
let ep = ort::execution_providers::QNNExecutionProvider::default()
|
||||
.with_device_id(id as _);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register QNN: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "QNN")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::MiGraphX(id) => {
|
||||
#[cfg(not(feature = "migraphx"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "MIGraphX")
|
||||
.replace("#FEATURE", "migraphx"));
|
||||
}
|
||||
#[cfg(feature = "migraphx")]
|
||||
{
|
||||
let ep = ort::execution_providers::MIGraphXExecutionProvider::default()
|
||||
.with_device_id(id as _)
|
||||
.with_fp16(self.migraphx_fp16)
|
||||
.with_exhaustive_tune(self.migraphx_exhaustive_tune);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register MIGraphX: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "MIGraphX")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::Vitis => {
|
||||
#[cfg(not(feature = "vitis"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "VitisAI")
|
||||
.replace("#FEATURE", "vitis"));
|
||||
}
|
||||
#[cfg(feature = "vitis")]
|
||||
{
|
||||
let ep = ort::execution_providers::VitisAIExecutionProvider::default()
|
||||
.with_cache_dir(
|
||||
crate::Dir::Cache
|
||||
.crate_dir_default_with_subs(&["caches", "vitis"])?
|
||||
.display()
|
||||
.to_string(),
|
||||
);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register VitisAI: {}", err)
|
||||
})?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "VitisAI")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Device::Azure => {
|
||||
#[cfg(not(feature = "azure"))]
|
||||
{
|
||||
anyhow::bail!(feature_help
|
||||
.replace("#EP", "Azure")
|
||||
.replace("#FEATURE", "azure"));
|
||||
}
|
||||
#[cfg(feature = "azure")]
|
||||
{
|
||||
let ep = ort::execution_providers::AzureExecutionProvider::default();
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder).map_err(|err| {
|
||||
anyhow::anyhow!("Failed to register Azure: {}", err)
|
||||
})?;
|
||||
builder = builder.with_extensions()?;
|
||||
}
|
||||
_ => anyhow::bail!(compile_help.replace("#EP", "Azure")),
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let ep = ort::execution_providers::CPUExecutionProvider::default()
|
||||
.with_arena_allocator(true);
|
||||
.with_arena_allocator(self.cpu_arena_allocator);
|
||||
match ep.is_available() {
|
||||
Ok(true) => {
|
||||
ep.register(&mut builder)
|
||||
@ -481,7 +928,8 @@ impl Engine {
|
||||
};
|
||||
let session = builder
|
||||
.with_optimization_level(graph_opt_level)?
|
||||
.with_intra_threads(std::thread::available_parallelism()?.get())?
|
||||
.with_intra_threads(self.num_intra_threads.unwrap_or(n_threads_available))?
|
||||
.with_inter_threads(self.num_inter_threads.unwrap_or(2))?
|
||||
.commit_from_file(self.file())?;
|
||||
|
||||
Ok(session)
|
||||
|
@ -5,6 +5,7 @@ use log::{info, warn};
|
||||
use rayon::prelude::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::str::FromStr;
|
||||
use std::sync::mpsc;
|
||||
#[cfg(feature = "video")]
|
||||
use video_rs::{Decoder, Url};
|
||||
@ -80,9 +81,10 @@ impl std::fmt::Debug for DataLoader {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for DataLoader {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(source: &str) -> Result<Self, Self::Error> {
|
||||
impl FromStr for DataLoader {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(source: &str) -> Result<Self, Self::Err> {
|
||||
Self::new(source)
|
||||
}
|
||||
}
|
||||
|
@ -460,7 +460,7 @@ impl Hub {
|
||||
fn cache_file(owner: &str, repo: &str) -> String {
|
||||
let safe_owner = owner.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
|
||||
let safe_repo = repo.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
|
||||
format!(".cache-releases-{}-{}.json", safe_owner, safe_repo)
|
||||
format!("releases-{}-{}.json", safe_owner, safe_repo)
|
||||
}
|
||||
|
||||
fn get_releases(
|
||||
@ -470,7 +470,9 @@ impl Hub {
|
||||
to: &Dir,
|
||||
ttl: &Duration,
|
||||
) -> Result<Vec<Release>> {
|
||||
let cache = to.crate_dir_default()?.join(Self::cache_file(owner, repo));
|
||||
let cache = to
|
||||
.crate_dir_default_with_subs(&["caches"])?
|
||||
.join(Self::cache_file(owner, repo));
|
||||
let is_file_expired = Self::is_file_expired(&cache, ttl)?;
|
||||
let body = if is_file_expired {
|
||||
let gh_api_release = format!(
|
||||
|
@ -2,6 +2,7 @@ use aksr::Builder;
|
||||
use anyhow::Result;
|
||||
use ndarray::{s, Axis};
|
||||
use rand::{prelude::*, rng};
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::{
|
||||
elapsed, Config, DynConf, Engine, Image, Mask, Ops, Polygon, Processor, SamPrompt, Ts, Xs, X, Y,
|
||||
@ -16,10 +17,10 @@ pub enum SamKind {
|
||||
EdgeSam,
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for SamKind {
|
||||
type Error = anyhow::Error;
|
||||
impl FromStr for SamKind {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"sam" => Ok(Self::Sam),
|
||||
"sam2" => Ok(Self::Sam2),
|
||||
|
@ -2,6 +2,7 @@ use aksr::Builder;
|
||||
use anyhow::Result;
|
||||
use ndarray::{s, Axis};
|
||||
use rayon::prelude::*;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::{elapsed, Config, Engine, Image, LogitsSampler, Processor, Scale, Ts, Xs, X, Y};
|
||||
|
||||
@ -11,10 +12,10 @@ pub enum TrOCRKind {
|
||||
HandWritten,
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for TrOCRKind {
|
||||
type Error = anyhow::Error;
|
||||
impl FromStr for TrOCRKind {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"printed" => Ok(Self::Printed),
|
||||
"handwritten" | "hand-written" => Ok(Self::HandWritten),
|
||||
|
@ -261,6 +261,507 @@ impl Config {
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_graph_opt_level_all(mut self, level: u8) -> Self {
|
||||
self.visual = self.visual.with_graph_opt_level(level);
|
||||
self.textual = self.textual.with_graph_opt_level(level);
|
||||
self.model = self.model.with_graph_opt_level(level);
|
||||
self.encoder = self.encoder.with_graph_opt_level(level);
|
||||
self.decoder = self.decoder.with_graph_opt_level(level);
|
||||
self.visual_encoder = self.visual_encoder.with_graph_opt_level(level);
|
||||
self.textual_encoder = self.textual_encoder.with_graph_opt_level(level);
|
||||
self.visual_decoder = self.visual_decoder.with_graph_opt_level(level);
|
||||
self.textual_decoder = self.textual_decoder.with_graph_opt_level(level);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_graph_opt_level(level);
|
||||
self.size_encoder = self.size_encoder.with_graph_opt_level(level);
|
||||
self.size_decoder = self.size_decoder.with_graph_opt_level(level);
|
||||
self.coord_encoder = self.coord_encoder.with_graph_opt_level(level);
|
||||
self.coord_decoder = self.coord_decoder.with_graph_opt_level(level);
|
||||
self.visual_projection = self.visual_projection.with_graph_opt_level(level);
|
||||
self.textual_projection = self.textual_projection.with_graph_opt_level(level);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_num_intra_threads_all(mut self, num_threads: usize) -> Self {
|
||||
self.visual = self.visual.with_num_intra_threads(num_threads);
|
||||
self.textual = self.textual.with_num_intra_threads(num_threads);
|
||||
self.model = self.model.with_num_intra_threads(num_threads);
|
||||
self.encoder = self.encoder.with_num_intra_threads(num_threads);
|
||||
self.decoder = self.decoder.with_num_intra_threads(num_threads);
|
||||
self.visual_encoder = self.visual_encoder.with_num_intra_threads(num_threads);
|
||||
self.textual_encoder = self.textual_encoder.with_num_intra_threads(num_threads);
|
||||
self.visual_decoder = self.visual_decoder.with_num_intra_threads(num_threads);
|
||||
self.textual_decoder = self.textual_decoder.with_num_intra_threads(num_threads);
|
||||
self.textual_decoder_merged = self
|
||||
.textual_decoder_merged
|
||||
.with_num_intra_threads(num_threads);
|
||||
self.size_encoder = self.size_encoder.with_num_intra_threads(num_threads);
|
||||
self.size_decoder = self.size_decoder.with_num_intra_threads(num_threads);
|
||||
self.coord_encoder = self.coord_encoder.with_num_intra_threads(num_threads);
|
||||
self.coord_decoder = self.coord_decoder.with_num_intra_threads(num_threads);
|
||||
self.visual_projection = self.visual_projection.with_num_intra_threads(num_threads);
|
||||
self.textual_projection = self.textual_projection.with_num_intra_threads(num_threads);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_num_inter_threads_all(mut self, num_threads: usize) -> Self {
|
||||
self.visual = self.visual.with_num_inter_threads(num_threads);
|
||||
self.textual = self.textual.with_num_inter_threads(num_threads);
|
||||
self.model = self.model.with_num_inter_threads(num_threads);
|
||||
self.encoder = self.encoder.with_num_inter_threads(num_threads);
|
||||
self.decoder = self.decoder.with_num_inter_threads(num_threads);
|
||||
self.visual_encoder = self.visual_encoder.with_num_inter_threads(num_threads);
|
||||
self.textual_encoder = self.textual_encoder.with_num_inter_threads(num_threads);
|
||||
self.visual_decoder = self.visual_decoder.with_num_inter_threads(num_threads);
|
||||
self.textual_decoder = self.textual_decoder.with_num_inter_threads(num_threads);
|
||||
self.textual_decoder_merged = self
|
||||
.textual_decoder_merged
|
||||
.with_num_inter_threads(num_threads);
|
||||
self.size_encoder = self.size_encoder.with_num_inter_threads(num_threads);
|
||||
self.size_decoder = self.size_decoder.with_num_inter_threads(num_threads);
|
||||
self.coord_encoder = self.coord_encoder.with_num_inter_threads(num_threads);
|
||||
self.coord_decoder = self.coord_decoder.with_num_inter_threads(num_threads);
|
||||
self.visual_projection = self.visual_projection.with_num_inter_threads(num_threads);
|
||||
self.textual_projection = self.textual_projection.with_num_inter_threads(num_threads);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_cpu_arena_allocator_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_cpu_arena_allocator(x);
|
||||
self.textual = self.textual.with_cpu_arena_allocator(x);
|
||||
self.model = self.model.with_cpu_arena_allocator(x);
|
||||
self.encoder = self.encoder.with_cpu_arena_allocator(x);
|
||||
self.decoder = self.decoder.with_cpu_arena_allocator(x);
|
||||
self.visual_encoder = self.visual_encoder.with_cpu_arena_allocator(x);
|
||||
self.textual_encoder = self.textual_encoder.with_cpu_arena_allocator(x);
|
||||
self.visual_decoder = self.visual_decoder.with_cpu_arena_allocator(x);
|
||||
self.textual_decoder = self.textual_decoder.with_cpu_arena_allocator(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_cpu_arena_allocator(x);
|
||||
self.size_encoder = self.size_encoder.with_cpu_arena_allocator(x);
|
||||
self.size_decoder = self.size_decoder.with_cpu_arena_allocator(x);
|
||||
self.coord_encoder = self.coord_encoder.with_cpu_arena_allocator(x);
|
||||
self.coord_decoder = self.coord_decoder.with_cpu_arena_allocator(x);
|
||||
self.visual_projection = self.visual_projection.with_cpu_arena_allocator(x);
|
||||
self.textual_projection = self.textual_projection.with_cpu_arena_allocator(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_openvino_dynamic_shapes_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_openvino_dynamic_shapes(x);
|
||||
self.textual = self.textual.with_openvino_dynamic_shapes(x);
|
||||
self.model = self.model.with_openvino_dynamic_shapes(x);
|
||||
self.encoder = self.encoder.with_openvino_dynamic_shapes(x);
|
||||
self.decoder = self.decoder.with_openvino_dynamic_shapes(x);
|
||||
self.visual_encoder = self.visual_encoder.with_openvino_dynamic_shapes(x);
|
||||
self.textual_encoder = self.textual_encoder.with_openvino_dynamic_shapes(x);
|
||||
self.visual_decoder = self.visual_decoder.with_openvino_dynamic_shapes(x);
|
||||
self.textual_decoder = self.textual_decoder.with_openvino_dynamic_shapes(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_openvino_dynamic_shapes(x);
|
||||
self.size_encoder = self.size_encoder.with_openvino_dynamic_shapes(x);
|
||||
self.size_decoder = self.size_decoder.with_openvino_dynamic_shapes(x);
|
||||
self.coord_encoder = self.coord_encoder.with_openvino_dynamic_shapes(x);
|
||||
self.coord_decoder = self.coord_decoder.with_openvino_dynamic_shapes(x);
|
||||
self.visual_projection = self.visual_projection.with_openvino_dynamic_shapes(x);
|
||||
self.textual_projection = self.textual_projection.with_openvino_dynamic_shapes(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_openvino_opencl_throttling_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_openvino_opencl_throttling(x);
|
||||
self.textual = self.textual.with_openvino_opencl_throttling(x);
|
||||
self.model = self.model.with_openvino_opencl_throttling(x);
|
||||
self.encoder = self.encoder.with_openvino_opencl_throttling(x);
|
||||
self.decoder = self.decoder.with_openvino_opencl_throttling(x);
|
||||
self.visual_encoder = self.visual_encoder.with_openvino_opencl_throttling(x);
|
||||
self.textual_encoder = self.textual_encoder.with_openvino_opencl_throttling(x);
|
||||
self.visual_decoder = self.visual_decoder.with_openvino_opencl_throttling(x);
|
||||
self.textual_decoder = self.textual_decoder.with_openvino_opencl_throttling(x);
|
||||
self.textual_decoder_merged = self
|
||||
.textual_decoder_merged
|
||||
.with_openvino_opencl_throttling(x);
|
||||
self.size_encoder = self.size_encoder.with_openvino_opencl_throttling(x);
|
||||
self.size_decoder = self.size_decoder.with_openvino_opencl_throttling(x);
|
||||
self.coord_encoder = self.coord_encoder.with_openvino_opencl_throttling(x);
|
||||
self.coord_decoder = self.coord_decoder.with_openvino_opencl_throttling(x);
|
||||
self.visual_projection = self.visual_projection.with_openvino_opencl_throttling(x);
|
||||
self.textual_projection = self.textual_projection.with_openvino_opencl_throttling(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_openvino_qdq_optimizer_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_openvino_qdq_optimizer(x);
|
||||
self.textual = self.textual.with_openvino_qdq_optimizer(x);
|
||||
self.model = self.model.with_openvino_qdq_optimizer(x);
|
||||
self.encoder = self.encoder.with_openvino_qdq_optimizer(x);
|
||||
self.decoder = self.decoder.with_openvino_qdq_optimizer(x);
|
||||
self.visual_encoder = self.visual_encoder.with_openvino_qdq_optimizer(x);
|
||||
self.textual_encoder = self.textual_encoder.with_openvino_qdq_optimizer(x);
|
||||
self.visual_decoder = self.visual_decoder.with_openvino_qdq_optimizer(x);
|
||||
self.textual_decoder = self.textual_decoder.with_openvino_qdq_optimizer(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_openvino_qdq_optimizer(x);
|
||||
self.size_encoder = self.size_encoder.with_openvino_qdq_optimizer(x);
|
||||
self.size_decoder = self.size_decoder.with_openvino_qdq_optimizer(x);
|
||||
self.coord_encoder = self.coord_encoder.with_openvino_qdq_optimizer(x);
|
||||
self.coord_decoder = self.coord_decoder.with_openvino_qdq_optimizer(x);
|
||||
self.visual_projection = self.visual_projection.with_openvino_qdq_optimizer(x);
|
||||
self.textual_projection = self.textual_projection.with_openvino_qdq_optimizer(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_openvino_num_threads_all(mut self, num_threads: usize) -> Self {
|
||||
self.visual = self.visual.with_openvino_num_threads(num_threads);
|
||||
self.textual = self.textual.with_openvino_num_threads(num_threads);
|
||||
self.model = self.model.with_openvino_num_threads(num_threads);
|
||||
self.encoder = self.encoder.with_openvino_num_threads(num_threads);
|
||||
self.decoder = self.decoder.with_openvino_num_threads(num_threads);
|
||||
self.visual_encoder = self.visual_encoder.with_openvino_num_threads(num_threads);
|
||||
self.textual_encoder = self.textual_encoder.with_openvino_num_threads(num_threads);
|
||||
self.visual_decoder = self.visual_decoder.with_openvino_num_threads(num_threads);
|
||||
self.textual_decoder = self.textual_decoder.with_openvino_num_threads(num_threads);
|
||||
self.textual_decoder_merged = self
|
||||
.textual_decoder_merged
|
||||
.with_openvino_num_threads(num_threads);
|
||||
self.size_encoder = self.size_encoder.with_openvino_num_threads(num_threads);
|
||||
self.size_decoder = self.size_decoder.with_openvino_num_threads(num_threads);
|
||||
self.coord_encoder = self.coord_encoder.with_openvino_num_threads(num_threads);
|
||||
self.coord_decoder = self.coord_decoder.with_openvino_num_threads(num_threads);
|
||||
self.visual_projection = self
|
||||
.visual_projection
|
||||
.with_openvino_num_threads(num_threads);
|
||||
self.textual_projection = self
|
||||
.textual_projection
|
||||
.with_openvino_num_threads(num_threads);
|
||||
self
|
||||
}
|
||||
|
||||
// onednn
|
||||
pub fn with_onednn_arena_allocator_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_onednn_arena_allocator(x);
|
||||
self.textual = self.textual.with_onednn_arena_allocator(x);
|
||||
self.model = self.model.with_onednn_arena_allocator(x);
|
||||
self.encoder = self.encoder.with_onednn_arena_allocator(x);
|
||||
self.decoder = self.decoder.with_onednn_arena_allocator(x);
|
||||
self.visual_encoder = self.visual_encoder.with_onednn_arena_allocator(x);
|
||||
self.textual_encoder = self.textual_encoder.with_onednn_arena_allocator(x);
|
||||
self.visual_decoder = self.visual_decoder.with_onednn_arena_allocator(x);
|
||||
self.textual_decoder = self.textual_decoder.with_onednn_arena_allocator(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_onednn_arena_allocator(x);
|
||||
self.size_encoder = self.size_encoder.with_onednn_arena_allocator(x);
|
||||
self.size_decoder = self.size_decoder.with_onednn_arena_allocator(x);
|
||||
self.coord_encoder = self.coord_encoder.with_onednn_arena_allocator(x);
|
||||
self.coord_decoder = self.coord_decoder.with_onednn_arena_allocator(x);
|
||||
self.visual_projection = self.visual_projection.with_onednn_arena_allocator(x);
|
||||
self.textual_projection = self.textual_projection.with_onednn_arena_allocator(x);
|
||||
self
|
||||
}
|
||||
|
||||
// tensorrt
|
||||
pub fn with_tensorrt_fp16_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_tensorrt_fp16(x);
|
||||
self.textual = self.textual.with_tensorrt_fp16(x);
|
||||
self.model = self.model.with_tensorrt_fp16(x);
|
||||
self.encoder = self.encoder.with_tensorrt_fp16(x);
|
||||
self.decoder = self.decoder.with_tensorrt_fp16(x);
|
||||
self.visual_encoder = self.visual_encoder.with_tensorrt_fp16(x);
|
||||
self.textual_encoder = self.textual_encoder.with_tensorrt_fp16(x);
|
||||
self.visual_decoder = self.visual_decoder.with_tensorrt_fp16(x);
|
||||
self.textual_decoder = self.textual_decoder.with_tensorrt_fp16(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_tensorrt_fp16(x);
|
||||
self.size_encoder = self.size_encoder.with_tensorrt_fp16(x);
|
||||
self.size_decoder = self.size_decoder.with_tensorrt_fp16(x);
|
||||
self.coord_encoder = self.coord_encoder.with_tensorrt_fp16(x);
|
||||
self.coord_decoder = self.coord_decoder.with_tensorrt_fp16(x);
|
||||
self.visual_projection = self.visual_projection.with_tensorrt_fp16(x);
|
||||
self.textual_projection = self.textual_projection.with_tensorrt_fp16(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tensorrt_engine_cache_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_tensorrt_engine_cache(x);
|
||||
self.textual = self.textual.with_tensorrt_engine_cache(x);
|
||||
self.model = self.model.with_tensorrt_engine_cache(x);
|
||||
self.encoder = self.encoder.with_tensorrt_engine_cache(x);
|
||||
self.decoder = self.decoder.with_tensorrt_engine_cache(x);
|
||||
self.visual_encoder = self.visual_encoder.with_tensorrt_engine_cache(x);
|
||||
self.textual_encoder = self.textual_encoder.with_tensorrt_engine_cache(x);
|
||||
self.visual_decoder = self.visual_decoder.with_tensorrt_engine_cache(x);
|
||||
self.textual_decoder = self.textual_decoder.with_tensorrt_engine_cache(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_tensorrt_engine_cache(x);
|
||||
self.size_encoder = self.size_encoder.with_tensorrt_engine_cache(x);
|
||||
self.size_decoder = self.size_decoder.with_tensorrt_engine_cache(x);
|
||||
self.coord_encoder = self.coord_encoder.with_tensorrt_engine_cache(x);
|
||||
self.coord_decoder = self.coord_decoder.with_tensorrt_engine_cache(x);
|
||||
self.visual_projection = self.visual_projection.with_tensorrt_engine_cache(x);
|
||||
self.textual_projection = self.textual_projection.with_tensorrt_engine_cache(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tensorrt_timing_cache_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_tensorrt_timing_cache(x);
|
||||
self.textual = self.textual.with_tensorrt_timing_cache(x);
|
||||
self.model = self.model.with_tensorrt_timing_cache(x);
|
||||
self.encoder = self.encoder.with_tensorrt_timing_cache(x);
|
||||
self.decoder = self.decoder.with_tensorrt_timing_cache(x);
|
||||
self.visual_encoder = self.visual_encoder.with_tensorrt_timing_cache(x);
|
||||
self.textual_encoder = self.textual_encoder.with_tensorrt_timing_cache(x);
|
||||
self.visual_decoder = self.visual_decoder.with_tensorrt_timing_cache(x);
|
||||
self.textual_decoder = self.textual_decoder.with_tensorrt_timing_cache(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_tensorrt_timing_cache(x);
|
||||
self.size_encoder = self.size_encoder.with_tensorrt_timing_cache(x);
|
||||
self.size_decoder = self.size_decoder.with_tensorrt_timing_cache(x);
|
||||
self.coord_encoder = self.coord_encoder.with_tensorrt_timing_cache(x);
|
||||
self.coord_decoder = self.coord_decoder.with_tensorrt_timing_cache(x);
|
||||
self.visual_projection = self.visual_projection.with_tensorrt_timing_cache(x);
|
||||
self.textual_projection = self.textual_projection.with_tensorrt_timing_cache(x);
|
||||
self
|
||||
}
|
||||
|
||||
// coreml
|
||||
pub fn with_coreml_static_input_shapes_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_coreml_static_input_shapes(x);
|
||||
self.textual = self.textual.with_coreml_static_input_shapes(x);
|
||||
self.model = self.model.with_coreml_static_input_shapes(x);
|
||||
self.encoder = self.encoder.with_coreml_static_input_shapes(x);
|
||||
self.decoder = self.decoder.with_coreml_static_input_shapes(x);
|
||||
self.visual_encoder = self.visual_encoder.with_coreml_static_input_shapes(x);
|
||||
self.textual_encoder = self.textual_encoder.with_coreml_static_input_shapes(x);
|
||||
self.visual_decoder = self.visual_decoder.with_coreml_static_input_shapes(x);
|
||||
self.textual_decoder = self.textual_decoder.with_coreml_static_input_shapes(x);
|
||||
self.textual_decoder_merged = self
|
||||
.textual_decoder_merged
|
||||
.with_coreml_static_input_shapes(x);
|
||||
self.size_encoder = self.size_encoder.with_coreml_static_input_shapes(x);
|
||||
self.size_decoder = self.size_decoder.with_coreml_static_input_shapes(x);
|
||||
self.coord_encoder = self.coord_encoder.with_coreml_static_input_shapes(x);
|
||||
self.coord_decoder = self.coord_decoder.with_coreml_static_input_shapes(x);
|
||||
self.visual_projection = self.visual_projection.with_coreml_static_input_shapes(x);
|
||||
self.textual_projection = self.textual_projection.with_coreml_static_input_shapes(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_coreml_subgraph_running_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_coreml_subgraph_running(x);
|
||||
self.textual = self.textual.with_coreml_subgraph_running(x);
|
||||
self.model = self.model.with_coreml_subgraph_running(x);
|
||||
self.encoder = self.encoder.with_coreml_subgraph_running(x);
|
||||
self.decoder = self.decoder.with_coreml_subgraph_running(x);
|
||||
self.visual_encoder = self.visual_encoder.with_coreml_subgraph_running(x);
|
||||
self.textual_encoder = self.textual_encoder.with_coreml_subgraph_running(x);
|
||||
self.visual_decoder = self.visual_decoder.with_coreml_subgraph_running(x);
|
||||
self.textual_decoder = self.textual_decoder.with_coreml_subgraph_running(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_coreml_subgraph_running(x);
|
||||
self.size_encoder = self.size_encoder.with_coreml_subgraph_running(x);
|
||||
self.size_decoder = self.size_decoder.with_coreml_subgraph_running(x);
|
||||
self.coord_encoder = self.coord_encoder.with_coreml_subgraph_running(x);
|
||||
self.coord_decoder = self.coord_decoder.with_coreml_subgraph_running(x);
|
||||
self.visual_projection = self.visual_projection.with_coreml_subgraph_running(x);
|
||||
self.textual_projection = self.textual_projection.with_coreml_subgraph_running(x);
|
||||
self
|
||||
}
|
||||
|
||||
// cann
|
||||
pub fn with_cann_graph_inference_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_cann_graph_inference(x);
|
||||
self.textual = self.textual.with_cann_graph_inference(x);
|
||||
self.model = self.model.with_cann_graph_inference(x);
|
||||
self.encoder = self.encoder.with_cann_graph_inference(x);
|
||||
self.decoder = self.decoder.with_cann_graph_inference(x);
|
||||
self.visual_encoder = self.visual_encoder.with_cann_graph_inference(x);
|
||||
self.textual_encoder = self.textual_encoder.with_cann_graph_inference(x);
|
||||
self.visual_decoder = self.visual_decoder.with_cann_graph_inference(x);
|
||||
self.textual_decoder = self.textual_decoder.with_cann_graph_inference(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_cann_graph_inference(x);
|
||||
self.size_encoder = self.size_encoder.with_cann_graph_inference(x);
|
||||
self.size_decoder = self.size_decoder.with_cann_graph_inference(x);
|
||||
self.coord_encoder = self.coord_encoder.with_cann_graph_inference(x);
|
||||
self.coord_decoder = self.coord_decoder.with_cann_graph_inference(x);
|
||||
self.visual_projection = self.visual_projection.with_cann_graph_inference(x);
|
||||
self.textual_projection = self.textual_projection.with_cann_graph_inference(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_cann_dump_graphs_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_cann_dump_graphs(x);
|
||||
self.textual = self.textual.with_cann_dump_graphs(x);
|
||||
self.model = self.model.with_cann_dump_graphs(x);
|
||||
self.encoder = self.encoder.with_cann_dump_graphs(x);
|
||||
self.decoder = self.decoder.with_cann_dump_graphs(x);
|
||||
self.visual_encoder = self.visual_encoder.with_cann_dump_graphs(x);
|
||||
self.textual_encoder = self.textual_encoder.with_cann_dump_graphs(x);
|
||||
self.visual_decoder = self.visual_decoder.with_cann_dump_graphs(x);
|
||||
self.textual_decoder = self.textual_decoder.with_cann_dump_graphs(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_cann_dump_graphs(x);
|
||||
self.size_encoder = self.size_encoder.with_cann_dump_graphs(x);
|
||||
self.size_decoder = self.size_decoder.with_cann_dump_graphs(x);
|
||||
self.coord_encoder = self.coord_encoder.with_cann_dump_graphs(x);
|
||||
self.coord_decoder = self.coord_decoder.with_cann_dump_graphs(x);
|
||||
self.visual_projection = self.visual_projection.with_cann_dump_graphs(x);
|
||||
self.textual_projection = self.textual_projection.with_cann_dump_graphs(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_cann_dump_om_model_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_cann_dump_om_model(x);
|
||||
self.textual = self.textual.with_cann_dump_om_model(x);
|
||||
self.model = self.model.with_cann_dump_om_model(x);
|
||||
self.encoder = self.encoder.with_cann_dump_om_model(x);
|
||||
self.decoder = self.decoder.with_cann_dump_om_model(x);
|
||||
self.visual_encoder = self.visual_encoder.with_cann_dump_om_model(x);
|
||||
self.textual_encoder = self.textual_encoder.with_cann_dump_om_model(x);
|
||||
self.visual_decoder = self.visual_decoder.with_cann_dump_om_model(x);
|
||||
self.textual_decoder = self.textual_decoder.with_cann_dump_om_model(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_cann_dump_om_model(x);
|
||||
self.size_encoder = self.size_encoder.with_cann_dump_om_model(x);
|
||||
self.size_decoder = self.size_decoder.with_cann_dump_om_model(x);
|
||||
self.coord_encoder = self.coord_encoder.with_cann_dump_om_model(x);
|
||||
self.coord_decoder = self.coord_decoder.with_cann_dump_om_model(x);
|
||||
self.visual_projection = self.visual_projection.with_cann_dump_om_model(x);
|
||||
self.textual_projection = self.textual_projection.with_cann_dump_om_model(x);
|
||||
self
|
||||
}
|
||||
|
||||
// nnapi
|
||||
pub fn with_nnapi_cpu_only_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_nnapi_cpu_only(x);
|
||||
self.textual = self.textual.with_nnapi_cpu_only(x);
|
||||
self.model = self.model.with_nnapi_cpu_only(x);
|
||||
self.encoder = self.encoder.with_nnapi_cpu_only(x);
|
||||
self.decoder = self.decoder.with_nnapi_cpu_only(x);
|
||||
self.visual_encoder = self.visual_encoder.with_nnapi_cpu_only(x);
|
||||
self.textual_encoder = self.textual_encoder.with_nnapi_cpu_only(x);
|
||||
self.visual_decoder = self.visual_decoder.with_nnapi_cpu_only(x);
|
||||
self.textual_decoder = self.textual_decoder.with_nnapi_cpu_only(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_cpu_only(x);
|
||||
self.size_encoder = self.size_encoder.with_nnapi_cpu_only(x);
|
||||
self.size_decoder = self.size_decoder.with_nnapi_cpu_only(x);
|
||||
self.coord_encoder = self.coord_encoder.with_nnapi_cpu_only(x);
|
||||
self.coord_decoder = self.coord_decoder.with_nnapi_cpu_only(x);
|
||||
self.visual_projection = self.visual_projection.with_nnapi_cpu_only(x);
|
||||
self.textual_projection = self.textual_projection.with_nnapi_cpu_only(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_nnapi_disable_cpu_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_nnapi_disable_cpu(x);
|
||||
self.textual = self.textual.with_nnapi_disable_cpu(x);
|
||||
self.model = self.model.with_nnapi_disable_cpu(x);
|
||||
self.encoder = self.encoder.with_nnapi_disable_cpu(x);
|
||||
self.decoder = self.decoder.with_nnapi_disable_cpu(x);
|
||||
self.visual_encoder = self.visual_encoder.with_nnapi_disable_cpu(x);
|
||||
self.textual_encoder = self.textual_encoder.with_nnapi_disable_cpu(x);
|
||||
self.visual_decoder = self.visual_decoder.with_nnapi_disable_cpu(x);
|
||||
self.textual_decoder = self.textual_decoder.with_nnapi_disable_cpu(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_disable_cpu(x);
|
||||
self.size_encoder = self.size_encoder.with_nnapi_disable_cpu(x);
|
||||
self.size_decoder = self.size_decoder.with_nnapi_disable_cpu(x);
|
||||
self.coord_encoder = self.coord_encoder.with_nnapi_disable_cpu(x);
|
||||
self.coord_decoder = self.coord_decoder.with_nnapi_disable_cpu(x);
|
||||
self.visual_projection = self.visual_projection.with_nnapi_disable_cpu(x);
|
||||
self.textual_projection = self.textual_projection.with_nnapi_disable_cpu(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_nnapi_fp16_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_nnapi_fp16(x);
|
||||
self.textual = self.textual.with_nnapi_fp16(x);
|
||||
self.model = self.model.with_nnapi_fp16(x);
|
||||
self.encoder = self.encoder.with_nnapi_fp16(x);
|
||||
self.decoder = self.decoder.with_nnapi_fp16(x);
|
||||
self.visual_encoder = self.visual_encoder.with_nnapi_fp16(x);
|
||||
self.textual_encoder = self.textual_encoder.with_nnapi_fp16(x);
|
||||
self.visual_decoder = self.visual_decoder.with_nnapi_fp16(x);
|
||||
self.textual_decoder = self.textual_decoder.with_nnapi_fp16(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_fp16(x);
|
||||
self.size_encoder = self.size_encoder.with_nnapi_fp16(x);
|
||||
self.size_decoder = self.size_decoder.with_nnapi_fp16(x);
|
||||
self.coord_encoder = self.coord_encoder.with_nnapi_fp16(x);
|
||||
self.coord_decoder = self.coord_decoder.with_nnapi_fp16(x);
|
||||
self.visual_projection = self.visual_projection.with_nnapi_fp16(x);
|
||||
self.textual_projection = self.textual_projection.with_nnapi_fp16(x);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_nnapi_nchw_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_nnapi_nchw(x);
|
||||
self.textual = self.textual.with_nnapi_nchw(x);
|
||||
self.model = self.model.with_nnapi_nchw(x);
|
||||
self.encoder = self.encoder.with_nnapi_nchw(x);
|
||||
self.decoder = self.decoder.with_nnapi_nchw(x);
|
||||
self.visual_encoder = self.visual_encoder.with_nnapi_nchw(x);
|
||||
self.textual_encoder = self.textual_encoder.with_nnapi_nchw(x);
|
||||
self.visual_decoder = self.visual_decoder.with_nnapi_nchw(x);
|
||||
self.textual_decoder = self.textual_decoder.with_nnapi_nchw(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_nnapi_nchw(x);
|
||||
self.size_encoder = self.size_encoder.with_nnapi_nchw(x);
|
||||
self.size_decoder = self.size_decoder.with_nnapi_nchw(x);
|
||||
self.coord_encoder = self.coord_encoder.with_nnapi_nchw(x);
|
||||
self.coord_decoder = self.coord_decoder.with_nnapi_nchw(x);
|
||||
self.visual_projection = self.visual_projection.with_nnapi_nchw(x);
|
||||
self.textual_projection = self.textual_projection.with_nnapi_nchw(x);
|
||||
self
|
||||
}
|
||||
|
||||
// armnn
|
||||
/// Forwards `x` to `with_armnn_arena_allocator` on every sub-config field of
/// `self` (the same sixteen fields the other `*_all` setters touch) and returns
/// the updated value.
pub fn with_armnn_arena_allocator_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_armnn_arena_allocator(x);
|
||||
self.textual = self.textual.with_armnn_arena_allocator(x);
|
||||
self.model = self.model.with_armnn_arena_allocator(x);
|
||||
self.encoder = self.encoder.with_armnn_arena_allocator(x);
|
||||
self.decoder = self.decoder.with_armnn_arena_allocator(x);
|
||||
self.visual_encoder = self.visual_encoder.with_armnn_arena_allocator(x);
|
||||
self.textual_encoder = self.textual_encoder.with_armnn_arena_allocator(x);
|
||||
self.visual_decoder = self.visual_decoder.with_armnn_arena_allocator(x);
|
||||
self.textual_decoder = self.textual_decoder.with_armnn_arena_allocator(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_armnn_arena_allocator(x);
|
||||
self.size_encoder = self.size_encoder.with_armnn_arena_allocator(x);
|
||||
self.size_decoder = self.size_decoder.with_armnn_arena_allocator(x);
|
||||
self.coord_encoder = self.coord_encoder.with_armnn_arena_allocator(x);
|
||||
self.coord_decoder = self.coord_decoder.with_armnn_arena_allocator(x);
|
||||
self.visual_projection = self.visual_projection.with_armnn_arena_allocator(x);
|
||||
self.textual_projection = self.textual_projection.with_armnn_arena_allocator(x);
|
||||
self
|
||||
}
|
||||
|
||||
// migraphx
|
||||
/// Forwards `x` to `with_migraphx_fp16` on every sub-config field of `self`
/// and returns the updated value.
pub fn with_migraphx_fp16_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_migraphx_fp16(x);
|
||||
self.textual = self.textual.with_migraphx_fp16(x);
|
||||
self.model = self.model.with_migraphx_fp16(x);
|
||||
self.encoder = self.encoder.with_migraphx_fp16(x);
|
||||
self.decoder = self.decoder.with_migraphx_fp16(x);
|
||||
self.visual_encoder = self.visual_encoder.with_migraphx_fp16(x);
|
||||
self.textual_encoder = self.textual_encoder.with_migraphx_fp16(x);
|
||||
self.visual_decoder = self.visual_decoder.with_migraphx_fp16(x);
|
||||
self.textual_decoder = self.textual_decoder.with_migraphx_fp16(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_migraphx_fp16(x);
|
||||
self.size_encoder = self.size_encoder.with_migraphx_fp16(x);
|
||||
self.size_decoder = self.size_decoder.with_migraphx_fp16(x);
|
||||
self.coord_encoder = self.coord_encoder.with_migraphx_fp16(x);
|
||||
self.coord_decoder = self.coord_decoder.with_migraphx_fp16(x);
|
||||
self.visual_projection = self.visual_projection.with_migraphx_fp16(x);
|
||||
self.textual_projection = self.textual_projection.with_migraphx_fp16(x);
|
||||
self
|
||||
}
|
||||
|
||||
/// Forwards `x` to `with_migraphx_exhaustive_tune` on every sub-config field
/// of `self` and returns the updated value.
pub fn with_migraphx_exhaustive_tune_all(mut self, x: bool) -> Self {
|
||||
self.visual = self.visual.with_migraphx_exhaustive_tune(x);
|
||||
self.textual = self.textual.with_migraphx_exhaustive_tune(x);
|
||||
self.model = self.model.with_migraphx_exhaustive_tune(x);
|
||||
self.encoder = self.encoder.with_migraphx_exhaustive_tune(x);
|
||||
self.decoder = self.decoder.with_migraphx_exhaustive_tune(x);
|
||||
self.visual_encoder = self.visual_encoder.with_migraphx_exhaustive_tune(x);
|
||||
self.textual_encoder = self.textual_encoder.with_migraphx_exhaustive_tune(x);
|
||||
self.visual_decoder = self.visual_decoder.with_migraphx_exhaustive_tune(x);
|
||||
self.textual_decoder = self.textual_decoder.with_migraphx_exhaustive_tune(x);
|
||||
self.textual_decoder_merged = self.textual_decoder_merged.with_migraphx_exhaustive_tune(x);
|
||||
self.size_encoder = self.size_encoder.with_migraphx_exhaustive_tune(x);
|
||||
self.size_decoder = self.size_decoder.with_migraphx_exhaustive_tune(x);
|
||||
self.coord_encoder = self.coord_encoder.with_migraphx_exhaustive_tune(x);
|
||||
self.coord_decoder = self.coord_decoder.with_migraphx_exhaustive_tune(x);
|
||||
self.visual_projection = self.visual_projection.with_migraphx_exhaustive_tune(x);
|
||||
self.textual_projection = self.textual_projection.with_migraphx_exhaustive_tune(x);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// Invokes the builder-generating macro for `Config` and its `model` field.
// NOTE(review): presumably this expands to `with_model_*` setters that forward
// to `self.model` — confirm against the macro's matcher, which is not shown here.
impl_ort_config_methods!(Config, model);
|
||||
|
@ -3,7 +3,22 @@ pub enum Device {
|
||||
Cpu(usize),
|
||||
Cuda(usize),
|
||||
TensorRt(usize),
|
||||
CoreMl(usize),
|
||||
OpenVino(&'static str),
|
||||
DirectMl(usize),
|
||||
Cann(usize),
|
||||
Rocm(usize),
|
||||
Qnn(usize),
|
||||
MiGraphX(usize),
|
||||
CoreMl,
|
||||
Xnnpack,
|
||||
RkNpu,
|
||||
OneDnn,
|
||||
Acl,
|
||||
NnApi,
|
||||
ArmNn,
|
||||
Tvm,
|
||||
Vitis,
|
||||
Azure,
|
||||
}
|
||||
|
||||
impl Default for Device {
|
||||
@ -18,41 +33,97 @@ impl std::fmt::Display for Device {
|
||||
Self::Cpu(i) => format!("CPU:{}", i),
|
||||
Self::Cuda(i) => format!("CUDA:{}(NVIDIA)", i),
|
||||
Self::TensorRt(i) => format!("TensorRT:{}(NVIDIA)", i),
|
||||
Self::CoreMl(i) => format!("CoreML:{}(Apple)", i),
|
||||
Self::Cann(i) => format!("CANN:{}(Huawei)", i),
|
||||
Self::OpenVino(s) => format!("OpenVINO:{}(Intel)", s),
|
||||
Self::DirectMl(i) => format!("DirectML:{}(Microsoft)", i),
|
||||
Self::Qnn(i) => format!("QNN:{}(Qualcomm)", i),
|
||||
Self::MiGraphX(i) => format!("MIGraphX:{}(AMD)", i),
|
||||
Self::Rocm(i) => format!("ROCm:{}(AMD)", i),
|
||||
Self::CoreMl => "CoreML(Apple)".to_string(),
|
||||
Self::Azure => "Azure(Microsoft)".to_string(),
|
||||
Self::Xnnpack => "XNNPACK".to_string(),
|
||||
Self::OneDnn => "oneDNN(Intel)".to_string(),
|
||||
Self::RkNpu => "RKNPU".to_string(),
|
||||
Self::Acl => "ACL(Arm)".to_string(),
|
||||
Self::NnApi => "NNAPI(Android)".to_string(),
|
||||
Self::ArmNn => "ArmNN(Arm)".to_string(),
|
||||
Self::Tvm => "TVM(Apache)".to_string(),
|
||||
Self::Vitis => "VitisAI(AMD)".to_string(),
|
||||
};
|
||||
write!(f, "{}", x)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Device {
|
||||
type Error = anyhow::Error;
|
||||
impl std::str::FromStr for Device {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
||||
// device and its id
|
||||
let d_id: Vec<&str> = s.trim().split(':').collect();
|
||||
let (d, id) = match d_id.len() {
|
||||
1 => (d_id[0].trim(), 0),
|
||||
2 => (d_id[0].trim(), d_id[1].trim().parse::<usize>().unwrap_or(0)),
|
||||
_ => anyhow::bail!(
|
||||
"Fail to parse device string: {s}. Expect: `device:device_id` or `device`. e.g. `cuda:0` or `cuda`"
|
||||
),
|
||||
};
|
||||
// TODO: device-id checking
|
||||
match d.to_lowercase().as_str() {
|
||||
"cpu" => Ok(Self::Cpu(id)),
|
||||
"cuda" => Ok(Self::Cuda(id)),
|
||||
"trt" | "tensorrt" => Ok(Self::TensorRt(id)),
|
||||
"coreml" | "mps" => Ok(Self::CoreMl(id)),
|
||||
/// Parses a device string of the form `device` or `device:spec`.
///
/// The trimmed input is split at the first `:` only. The device name is
/// lowercased before matching, and an unrecognized name yields an error.
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// A missing or unparsable id falls back to 0 instead of producing an error.
#[inline]
|
||||
fn parse_device_id(id_str: Option<&str>) -> usize {
|
||||
id_str
|
||||
.map(|s| s.trim().parse::<usize>().unwrap_or(0))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
// Use split_once for better performance - no Vec allocation
// NOTE(review): only the whole input is trimmed, not the device part after the
// split, so an input with whitespace before the colon (e.g. "cuda :0") keeps a
// trailing space in `device_type` and ends up in the unsupported-device arm.
|
||||
let (device_type, id_part) = s
|
||||
.trim()
|
||||
.split_once(':')
|
||||
.map_or_else(|| (s.trim(), None), |(device, id)| (device, Some(id)));
|
||||
|
||||
match device_type.to_lowercase().as_str() {
|
||||
"cpu" => Ok(Self::Cpu(parse_device_id(id_part))),
|
||||
"cuda" => Ok(Self::Cuda(parse_device_id(id_part))),
|
||||
"trt" | "tensorrt" => Ok(Self::TensorRt(parse_device_id(id_part))),
|
||||
// Variants without a payload ignore whatever follows the colon.
"coreml" | "mps" => Ok(Self::CoreMl),
|
||||
"openvino" => {
|
||||
// For OpenVino, use the user input directly after first colon (trimmed)
|
||||
let device_spec = id_part.map(|s| s.trim()).unwrap_or("CPU"); // Default to CPU if no specification provided
|
||||
// `Box::leak` yields the `&'static str` the variant stores; the allocation is
// never freed, so every parse of an OpenVINO device leaks its spec string.
Ok(Self::OpenVino(Box::leak(
|
||||
device_spec.to_string().into_boxed_str(),
|
||||
)))
|
||||
}
|
||||
"directml" => Ok(Self::DirectMl(parse_device_id(id_part))),
|
||||
"xnnpack" => Ok(Self::Xnnpack),
|
||||
"cann" => Ok(Self::Cann(parse_device_id(id_part))),
|
||||
"rknpu" => Ok(Self::RkNpu),
|
||||
"onednn" => Ok(Self::OneDnn),
|
||||
"acl" => Ok(Self::Acl),
|
||||
"rocm" => Ok(Self::Rocm(parse_device_id(id_part))),
|
||||
"nnapi" => Ok(Self::NnApi),
|
||||
"armnn" => Ok(Self::ArmNn),
|
||||
"tvm" => Ok(Self::Tvm),
|
||||
"qnn" => Ok(Self::Qnn(parse_device_id(id_part))),
|
||||
"migraphx" => Ok(Self::MiGraphX(parse_device_id(id_part))),
|
||||
// NOTE(review): only "vitisai" is accepted; confirm whether "vitis" should be an alias.
"vitisai" => Ok(Self::Vitis),
|
||||
"azure" => Ok(Self::Azure),
|
||||
_ => anyhow::bail!("Unsupported device str: {s:?}."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub fn id(&self) -> usize {
|
||||
/// Returns the numeric id carried by the device, or `None` for variants that
/// have no numeric payload (`OpenVino` holds a string and also yields `None`).
pub fn id(&self) -> Option<usize> {
|
||||
match self {
|
||||
// NOTE(review): the single-line arm below yields a bare `usize` and matches
// `CoreMl(i)`; it looks like the pre-change row of this diff — confirm.
Self::Cpu(i) | Self::Cuda(i) | Self::TensorRt(i) | Self::CoreMl(i) => *i,
|
||||
Self::Cpu(i)
|
||||
| Self::Cuda(i)
|
||||
| Self::TensorRt(i)
|
||||
| Self::Cann(i)
|
||||
| Self::Qnn(i)
|
||||
| Self::Rocm(i)
|
||||
| Self::MiGraphX(i)
|
||||
| Self::DirectMl(i) => Some(*i),
|
||||
Self::OpenVino(_)
|
||||
| Self::Xnnpack
|
||||
| Self::CoreMl
|
||||
| Self::RkNpu
|
||||
| Self::OneDnn
|
||||
| Self::NnApi
|
||||
| Self::Azure
|
||||
| Self::Vitis
|
||||
| Self::ArmNn
|
||||
| Self::Tvm
|
||||
| Self::Acl => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,10 +29,10 @@ pub enum DType {
|
||||
Complex128,
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for DType {
|
||||
type Error = anyhow::Error;
|
||||
impl std::str::FromStr for DType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"auto" | "dyn" => Ok(Self::Auto),
|
||||
"u4" | "uint4" => Ok(Self::Uint4),
|
||||
|
@ -9,10 +9,42 @@ pub struct ORTConfig {
|
||||
pub device: Device,
|
||||
pub iiixs: Vec<Iiix>,
|
||||
pub num_dry_run: usize,
|
||||
pub trt_fp16: bool,
|
||||
pub graph_opt_level: Option<u8>,
|
||||
pub spec: String, // TODO: move out
|
||||
pub dtype: DType, // For dynamically loading the model
|
||||
// global
|
||||
pub graph_opt_level: Option<u8>,
|
||||
pub num_intra_threads: Option<usize>,
|
||||
pub num_inter_threads: Option<usize>,
|
||||
// cpu
|
||||
pub cpu_arena_allocator: bool,
|
||||
// openvino
|
||||
pub openvino_dynamic_shapes: bool,
|
||||
pub openvino_opencl_throttling: bool,
|
||||
pub openvino_qdq_optimizer: bool,
|
||||
pub openvino_num_threads: Option<usize>,
|
||||
// onednn
|
||||
pub onednn_arena_allocator: bool,
|
||||
// tensorrt
|
||||
pub tensorrt_fp16: bool,
|
||||
pub tensorrt_engine_cache: bool,
|
||||
pub tensorrt_timing_cache: bool,
|
||||
// coreml
|
||||
pub coreml_static_input_shapes: bool,
|
||||
pub coreml_subgraph_running: bool,
|
||||
// cann
|
||||
pub cann_graph_inference: bool,
|
||||
pub cann_dump_graphs: bool,
|
||||
pub cann_dump_om_model: bool,
|
||||
// nnapi
|
||||
pub nnapi_cpu_only: bool,
|
||||
pub nnapi_disable_cpu: bool,
|
||||
pub nnapi_fp16: bool,
|
||||
pub nnapi_nchw: bool,
|
||||
// armnn
|
||||
pub armnn_arena_allocator: bool,
|
||||
// migraphx
|
||||
pub migraphx_fp16: bool,
|
||||
pub migraphx_exhaustive_tune: bool,
|
||||
}
|
||||
|
||||
impl Default for ORTConfig {
|
||||
@ -21,11 +53,33 @@ impl Default for ORTConfig {
|
||||
file: Default::default(),
|
||||
device: Default::default(),
|
||||
iiixs: Default::default(),
|
||||
graph_opt_level: Default::default(),
|
||||
spec: Default::default(),
|
||||
dtype: Default::default(),
|
||||
num_dry_run: 3,
|
||||
trt_fp16: true,
|
||||
graph_opt_level: Default::default(),
|
||||
num_intra_threads: None,
|
||||
num_inter_threads: None,
|
||||
cpu_arena_allocator: true,
|
||||
openvino_dynamic_shapes: true,
|
||||
openvino_opencl_throttling: true,
|
||||
openvino_qdq_optimizer: true,
|
||||
openvino_num_threads: None,
|
||||
coreml_static_input_shapes: false,
|
||||
coreml_subgraph_running: true,
|
||||
tensorrt_fp16: true,
|
||||
tensorrt_engine_cache: true,
|
||||
tensorrt_timing_cache: false,
|
||||
cann_graph_inference: true,
|
||||
cann_dump_graphs: false,
|
||||
cann_dump_om_model: false,
|
||||
onednn_arena_allocator: true,
|
||||
nnapi_cpu_only: false,
|
||||
nnapi_disable_cpu: false,
|
||||
nnapi_fp16: true,
|
||||
nnapi_nchw: false,
|
||||
armnn_arena_allocator: true,
|
||||
migraphx_fp16: true,
|
||||
migraphx_exhaustive_tune: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -122,10 +176,6 @@ macro_rules! impl_ort_config_methods {
|
||||
self.$field = self.$field.with_device(device);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _trt_fp16>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_trt_fp16(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _num_dry_run>](mut self, x: usize) -> Self {
|
||||
self.$field = self.$field.with_num_dry_run(x);
|
||||
self
|
||||
@ -134,6 +184,113 @@ macro_rules! impl_ort_config_methods {
|
||||
self.$field = self.$field.with_ixx(i, ii, x);
|
||||
self
|
||||
}
|
||||
// global
|
||||
pub fn [<with_ $field _graph_opt_level>](mut self, x: u8) -> Self {
|
||||
self.$field = self.$field.with_graph_opt_level(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _num_intra_threads>](mut self, x: usize) -> Self {
|
||||
self.$field = self.$field.with_num_intra_threads(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _num_inter_threads>](mut self, x: usize) -> Self {
|
||||
self.$field = self.$field.with_num_inter_threads(x);
|
||||
self
|
||||
}
|
||||
// cpu
|
||||
pub fn [<with_ $field _cpu_arena_allocator>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_cpu_arena_allocator(x);
|
||||
self
|
||||
}
|
||||
// openvino
|
||||
pub fn [<with_ $field _openvino_dynamic_shapes>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_openvino_dynamic_shapes(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _openvino_opencl_throttling>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_openvino_opencl_throttling(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _openvino_qdq_optimizer>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_openvino_qdq_optimizer(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _openvino_num_threads>](mut self, x: usize) -> Self {
|
||||
self.$field = self.$field.with_openvino_num_threads(x);
|
||||
self
|
||||
}
|
||||
// onednn
|
||||
pub fn [<with_ $field _onednn_arena_allocator>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_onednn_arena_allocator(x);
|
||||
self
|
||||
}
|
||||
|
||||
// tensorrt
|
||||
pub fn [<with_ $field _tensorrt_fp16>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_tensorrt_fp16(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _tensorrt_engine_cache>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_tensorrt_engine_cache(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _tensorrt_timing_cache>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_tensorrt_timing_cache(x);
|
||||
self
|
||||
}
|
||||
// coreml
|
||||
pub fn [<with_ $field _coreml_static_input_shapes>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_coreml_static_input_shapes(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _coreml_subgraph_running>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_coreml_subgraph_running(x);
|
||||
self
|
||||
}
|
||||
// cann
|
||||
pub fn [<with_ $field _cann_graph_inference>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_cann_graph_inference(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _cann_dump_graphs>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_cann_dump_graphs(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _cann_dump_om_model>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_cann_dump_om_model(x);
|
||||
self
|
||||
}
|
||||
// nnapi
|
||||
pub fn [<with_ $field _nnapi_cpu_only>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_nnapi_cpu_only(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _nnapi_disable_cpu>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_nnapi_disable_cpu(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _nnapi_fp16>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_nnapi_fp16(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _nnapi_nchw>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_nnapi_nchw(x);
|
||||
self
|
||||
}
|
||||
// armnn
|
||||
pub fn [<with_ $field _armnn_arena_allocator>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_armnn_arena_allocator(x);
|
||||
self
|
||||
}
|
||||
// migraphx
|
||||
pub fn [<with_ $field _migraphx_fp16>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_migraphx_fp16(x);
|
||||
self
|
||||
}
|
||||
pub fn [<with_ $field _migraphx_exhaustive_tune>](mut self, x: bool) -> Self {
|
||||
self.$field = self.$field.with_migraphx_exhaustive_tune(x);
|
||||
self
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -1,3 +1,5 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd)]
|
||||
pub enum Scale {
|
||||
N,
|
||||
@ -64,10 +66,10 @@ impl TryFrom<char> for Scale {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Scale {
|
||||
type Error = anyhow::Error;
|
||||
impl FromStr for Scale {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"n" | "nano" => Ok(Self::N),
|
||||
"t" | "tiny" => Ok(Self::T),
|
||||
|
@ -1,3 +1,5 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
#[derive(Debug, Clone, Ord, Eq, PartialOrd, PartialEq)]
|
||||
pub enum Task {
|
||||
/// Image classification task.
|
||||
@ -164,10 +166,10 @@ impl std::fmt::Display for Task {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Task {
|
||||
type Error = anyhow::Error;
|
||||
impl FromStr for Task {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// TODO
|
||||
match s.to_lowercase().as_str() {
|
||||
"cls" | "classify" | "classification" => Ok(Self::ImageClassification),
|
||||
|
@ -61,20 +61,24 @@ impl From<Color> for [u8; 3] {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Color {
|
||||
type Error = &'static str;
|
||||
impl std::str::FromStr for Color {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn try_from(x: &str) -> Result<Self, Self::Error> {
|
||||
/// Parses a hexadecimal color string of 6 or 8 characters, ignoring any
/// leading `#` characters.
///
/// Returns an error when the length is neither 6 nor 8 or when the text is not
/// valid base-16.
fn from_str(x: &str) -> Result<Self, Self::Err> {
|
||||
let hex = x.trim_start_matches('#');
|
||||
// `len()` counts bytes. Six-character input gets `ff` appended as the last
// byte (presumably a fully opaque alpha — confirm).
let hex = match hex.len() {
|
||||
6 => format!("{}ff", hex),
|
||||
8 => hex.to_string(),
|
||||
// NOTE(review): the next row returns a `&'static str` error and looks like the
// pre-change row of this diff — confirm.
_ => return Err("Failed to convert `Color` from str: invalid length"),
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to convert `Color` from str: invalid length"
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// NOTE(review): `from_str_radix` also accepts a leading `+`, so the length
// check alone does not guarantee that every character is a hex digit.
u32::from_str_radix(&hex, 16)
|
||||
.map(Self)
|
||||
.map_err(|_| "Failed to convert `Color` from str: invalid hex")
|
||||
.map_err(|_| anyhow::anyhow!("Failed to convert `Color` from str: invalid hex"))
|
||||
}
|
||||
}
|
||||
|
||||
@ -151,17 +155,8 @@ impl Color {
|
||||
xs.iter().copied().map(Into::into).collect()
|
||||
}
|
||||
|
||||
pub fn try_create_palette<A: TryInto<Self> + Copy>(xs: &[A]) -> Result<Vec<Self>>
|
||||
where
|
||||
<A as TryInto<Self>>::Error: std::fmt::Debug,
|
||||
{
|
||||
xs.iter()
|
||||
.copied()
|
||||
.map(|x| {
|
||||
x.try_into()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to convert: {:?}", e))
|
||||
})
|
||||
.collect()
|
||||
/// Parses every string in `xs` via `str::parse` and collects the results,
/// stopping at and returning the first parse error.
pub fn try_create_palette(xs: &[&str]) -> Result<Vec<Self>> {
|
||||
xs.iter().map(|x| x.parse()).collect()
|
||||
}
|
||||
|
||||
pub fn palette_rand(n: usize) -> Vec<Self> {
|
||||
|
@ -14,20 +14,22 @@ pub enum ColorMap256 {
|
||||
SmoothCoolWarm,
|
||||
}
|
||||
|
||||
impl From<&str> for ColorMap256 {
|
||||
fn from(s: &str) -> Self {
|
||||
impl std::str::FromStr for ColorMap256 {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
/// Parses a colormap name; the input is lowercased before matching, and an
/// unknown name produces an error.
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
// NOTE(review): the arms that yield a bare `Self::…` and the `unimplemented!`
// arm look like the pre-change rows of this diff; the `Ok(…)`/`Err(…)` arms
// further down are the ones consistent with the `Result` return type.
"turbo" => Self::Turbo,
|
||||
"inferno" => Self::Inferno,
|
||||
"plasma" => Self::Plasma,
|
||||
"viridis" => Self::Viridis,
|
||||
"magma" => Self::Magma,
|
||||
"bentcoolwarm" => Self::BentCoolWarm,
|
||||
"blackbody" => Self::BlackBody,
|
||||
"extendedkindlmann" => Self::ExtendedKindLmann,
|
||||
"kindlmann" => Self::KindLmann,
|
||||
"smoothcoolwarm" => Self::SmoothCoolWarm,
|
||||
s => unimplemented!("{} is not supported for now!", s),
|
||||
"turbo" => Ok(Self::Turbo),
|
||||
"inferno" => Ok(Self::Inferno),
|
||||
"plasma" => Ok(Self::Plasma),
|
||||
"viridis" => Ok(Self::Viridis),
|
||||
"magma" => Ok(Self::Magma),
|
||||
"bentcoolwarm" => Ok(Self::BentCoolWarm),
|
||||
"blackbody" => Ok(Self::BlackBody),
|
||||
"extendedkindlmann" => Ok(Self::ExtendedKindLmann),
|
||||
"kindlmann" => Ok(Self::KindLmann),
|
||||
"smoothcoolwarm" => Ok(Self::SmoothCoolWarm),
|
||||
_ => Err(anyhow::anyhow!("Unsupported colormap: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user