From d3c738b5cf1a8c40edbcd67037effa2b98715179 Mon Sep 17 00:00:00 2001 From: jamjamjon Date: Tue, 20 May 2025 17:01:27 +0800 Subject: [PATCH] update --- examples/ben2/main.rs | 4 +- examples/blip/main.rs | 4 +- examples/classifier/main.rs | 12 +-- examples/clip/main.rs | 4 +- examples/d-fine/main.rs | 4 +- examples/db/main.rs | 6 +- examples/deim/main.rs | 4 +- examples/depth-anything/main.rs | 4 +- examples/depth-pro/main.rs | 4 +- examples/dinov2/main.rs | 4 +- examples/doclayout-yolo/main.rs | 4 +- examples/fast/main.rs | 8 +- examples/fastsam/main.rs | 4 +- examples/florence2/main.rs | 4 +- examples/grounding-dino/main.rs | 4 +- examples/linknet/main.rs | 8 +- examples/modnet/main.rs | 4 +- examples/moondream2/main.rs | 6 +- examples/owlv2/main.rs | 4 +- examples/picodet-layout/main.rs | 4 +- examples/rfdetr/main.rs | 4 +- examples/rmbg/main.rs | 6 +- examples/rtdetr/main.rs | 4 +- examples/rtmo/main.rs | 4 +- examples/sam/main.rs | 16 +-- examples/sam2/main.rs | 10 +- examples/sapiens/main.rs | 4 +- examples/slanet/main.rs | 4 +- examples/smolvlm/main.rs | 6 +- examples/svtr/main.rs | 4 +- examples/trocr/main.rs | 10 +- examples/ultralytics-rtdetr/main.rs | 4 +- examples/yolo-sam2/main.rs | 6 +- examples/yolo/main.rs | 9 +- examples/yoloe/main.rs | 4 +- examples/yolop/main.rs | 4 +- src/inference/engine.rs | 8 +- src/inference/hbb.rs | 4 +- src/inference/keypoint.rs | 1 - src/inference/mask.rs | 1 - src/inference/mod.rs | 4 - src/inference/obb.rs | 2 +- src/inference/polygon.rs | 3 +- src/models/beit/config.rs | 2 +- src/models/ben2/config.rs | 2 +- src/models/blip/config.rs | 2 +- src/models/blip/impl.rs | 4 +- src/models/clip/config.rs | 2 +- src/models/clip/impl.rs | 4 +- src/models/convnext/config.rs | 2 +- src/models/d_fine/config.rs | 2 +- src/models/db/config.rs | 2 +- src/models/db/impl.rs | 5 +- src/models/deim/config.rs | 2 +- src/models/deit/config.rs | 2 +- src/models/depth_anything/config.rs | 2 +- src/models/depth_anything/impl.rs | 4 +- src/models/depth_pro/config.rs | 2 +- src/models/depth_pro/impl.rs | 4 +- src/models/dinov2/config.rs | 2 +- src/models/dinov2/impl.rs | 4 +- src/models/fast/config.rs | 2 +- src/models/fastvit/config.rs | 2 +- src/models/florence2/config.rs | 2 +- src/models/florence2/impl.rs | 4 +- src/models/grounding_dino/config.rs | 2 +- src/models/grounding_dino/impl.rs | 4 +- src/models/linknet/config.rs | 2 +- src/models/mobileone/config.rs | 2 +- src/models/modnet/config.rs | 2 +- src/models/modnet/impl.rs | 4 +- src/models/moondream2/config.rs | 2 +- src/models/moondream2/impl.rs | 6 +- src/models/owl/config.rs | 2 +- src/models/owl/impl.rs | 4 +- src/models/picodet/config.rs | 2 +- src/models/picodet/impl.rs | 4 +- src/models/pipeline/basemodel.rs | 6 +- src/models/pipeline/image_classifier.rs | 8 +- src/models/rfdetr/config.rs | 2 +- src/models/rfdetr/impl.rs | 4 +- src/models/rmbg/config.rs | 2 +- src/models/rmbg/impl.rs | 4 +- src/models/rtdetr/config.rs | 2 +- src/models/rtdetr/impl.rs | 4 +- src/models/rtmo/config.rs | 2 +- src/models/rtmo/impl.rs | 4 +- src/models/sam/config.rs | 4 +- src/models/sam/impl.rs | 5 +- src/models/sam2/config.rs | 4 +- src/models/sam2/impl.rs | 4 +- src/models/sapiens/config.rs | 2 +- src/models/sapiens/impl.rs | 4 +- src/models/slanet/config.rs | 2 +- src/models/slanet/impl.rs | 4 +- src/models/smolvlm/config.rs | 2 +- src/models/smolvlm/impl.rs | 4 +- src/models/svtr/config.rs | 2 +- src/models/svtr/impl.rs | 4 +- src/models/trocr/config.rs | 2 +- src/models/trocr/impl.rs | 4 +- src/models/yolo/config.rs | 4 +- src/models/yolo/impl.rs | 14 +-- src/models/yolop/config.rs | 2 +- src/models/yolop/impl.rs | 4 +- .../model_config.rs => utils/config.rs} | 101 +++++++++++------- src/utils/mod.rs | 4 + src/utils/ops.rs | 2 +- .../engine_config.rs => utils/ort_config.rs} | 27 +++-- src/utils/processor_config.rs | 2 +- 110 files changed, 298 insertions(+), 264 deletions(-) rename src/{inference/model_config.rs => utils/config.rs} (76%) rename src/{inference/engine_config.rs => utils/ort_config.rs} (87%) diff --git a/examples/ben2/main.rs b/examples/ben2/main.rs index 5d0375c..8d44522 100644 --- a/examples/ben2/main.rs +++ b/examples/ben2/main.rs @@ -1,4 +1,4 @@ -use usls::{models::RMBG, Annotator, DataLoader, ModelConfig}; +use usls::{models::RMBG, Annotator, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -20,7 +20,7 @@ fn main() -> anyhow::Result<()> { let args: Args = argh::from_env(); // build model - let config = ModelConfig::ben2_base() + let config = Config::ben2_base() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) .commit()?; diff --git a/examples/blip/main.rs b/examples/blip/main.rs index 6367a67..6fd9456 100644 --- a/examples/blip/main.rs +++ b/examples/blip/main.rs @@ -1,4 +1,4 @@ -use usls::{models::Blip, DataLoader, ModelConfig}; +use usls::{models::Blip, Config, DataLoader}; #[derive(argh::FromArgs)] /// BLIP Example @@ -20,7 +20,7 @@ fn main() -> anyhow::Result<()> { let args: Args = argh::from_env(); // build model - let config = ModelConfig::blip_v1_base_caption() + let config = Config::blip_v1_base_caption() .with_device_all(args.device.as_str().try_into()?) .commit()?; let mut model = Blip::new(config)?; diff --git a/examples/classifier/main.rs b/examples/classifier/main.rs index cb0d987..d7bc9c1 100644 --- a/examples/classifier/main.rs +++ b/examples/classifier/main.rs @@ -1,4 +1,4 @@ -use usls::{models::ImageClassifier, Annotator, DataLoader, ModelConfig}; +use usls::{models::ImageClassifier, Annotator, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -37,11 +37,11 @@ fn main() -> anyhow::Result<()> { // build model let config = match args.model.to_lowercase().as_str() { - "beit" => ModelConfig::beit_base(), - "convnext" => ModelConfig::convnext_v2_atto(), - "deit" => ModelConfig::deit_tiny_distill(), - "fastvit" => ModelConfig::fastvit_t8_distill(), - "mobileone" => ModelConfig::mobileone_s0(), + "beit" => Config::beit_base(), + "convnext" => Config::convnext_v2_atto(), + "deit" => Config::deit_tiny_distill(), + "fastvit" => Config::fastvit_t8_distill(), + "mobileone" => Config::mobileone_s0(), _ => anyhow::bail!("Unsupported model: {}", args.model), }; diff --git a/examples/clip/main.rs b/examples/clip/main.rs index c650d00..90d1055 100644 --- a/examples/clip/main.rs +++ b/examples/clip/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::Clip, DataLoader, ModelConfig, Ops}; +use usls::{models::Clip, Config, DataLoader, Ops}; #[derive(argh::FromArgs)] /// CLIP Example @@ -17,7 +17,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = ModelConfig::jina_clip_v1() + let config = Config::jina_clip_v1() .with_device_all(args.device.as_str().try_into()?) .commit()?; let mut model = Clip::new(config)?; diff --git a/examples/d-fine/main.rs b/examples/d-fine/main.rs index ffdfaca..b7c1548 100644 --- a/examples/d-fine/main.rs +++ b/examples/d-fine/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RTDETR, Annotator, DataLoader, ModelConfig}; +use usls::{models::RTDETR, Annotator, Config, DataLoader}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,7 +8,7 @@ fn main() -> Result<()> { .init(); // config - let mut model = RTDETR::new(ModelConfig::d_fine_n_coco().commit()?)?; + let mut model = RTDETR::new(Config::d_fine_n_coco().commit()?)?; // load let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/db/main.rs b/examples/db/main.rs index 2f92e65..c4e6373 100644 --- a/examples/db/main.rs +++ b/examples/db/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::DB, Annotator, DataLoader, ModelConfig, Style}; +use usls::{models::DB, Annotator, Config, DataLoader, Style}; #[derive(argh::FromArgs)] /// Example @@ -42,8 +42,8 @@ fn main() -> Result<()> { // build model let config = match &args.model { - Some(m) => ModelConfig::db().with_model_file(m), - None => ModelConfig::ppocr_det_v4_ch().with_model_dtype(args.dtype.as_str().try_into()?), + Some(m) => Config::db().with_model_file(m), + None => Config::ppocr_det_v4_ch().with_model_dtype(args.dtype.as_str().try_into()?), } .with_device_all(args.device.as_str().try_into()?) .commit()?; diff --git a/examples/deim/main.rs b/examples/deim/main.rs index 6fc358f..980f518 100644 --- a/examples/deim/main.rs +++ b/examples/deim/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RTDETR, Annotator, DataLoader, ModelConfig}; +use usls::{models::RTDETR, Annotator, Config, DataLoader}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,7 +8,7 @@ fn main() -> Result<()> { .init(); // config - let mut model = RTDETR::new(ModelConfig::deim_dfine_s_coco().commit()?)?; + let mut model = RTDETR::new(Config::deim_dfine_s_coco().commit()?)?; // load let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/depth-anything/main.rs b/examples/depth-anything/main.rs index 3981ce2..2523063 100644 --- a/examples/depth-anything/main.rs +++ b/examples/depth-anything/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::DepthAnything, Annotator, DataLoader, ModelConfig, Style}; +use usls::{models::DepthAnything, Annotator, Config, DataLoader, Style}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,7 +8,7 @@ fn main() -> Result<()> { .init(); // build model - let mut model = DepthAnything::new(ModelConfig::depth_anything_v2_small().commit()?)?; + let mut model = DepthAnything::new(Config::depth_anything_v2_small().commit()?)?; // load let xs = DataLoader::try_read_n(&["images/street.jpg"])?; diff --git a/examples/depth-pro/main.rs b/examples/depth-pro/main.rs index c167d56..84002e6 100644 --- a/examples/depth-pro/main.rs +++ b/examples/depth-pro/main.rs @@ -1,6 +1,6 @@ use anyhow::Result; use usls::DataLoader; -use usls::{models::DepthPro, Annotator, ModelConfig, Style}; +use usls::{models::DepthPro, Annotator, Config, Style}; #[derive(argh::FromArgs)] /// Example @@ -23,7 +23,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // model - let config = ModelConfig::depth_pro() + let config = Config::depth_pro() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) .commit()?; diff --git a/examples/dinov2/main.rs b/examples/dinov2/main.rs index f9eda41..3c23a91 100644 --- a/examples/dinov2/main.rs +++ b/examples/dinov2/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::DINOv2, DataLoader, ModelConfig}; +use usls::{models::DINOv2, Config, DataLoader}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -11,7 +11,7 @@ fn main() -> Result<()> { let xs = DataLoader::try_read_n(&["./assets/bus.jpg", "./assets/bus.jpg"])?; // model - let config = ModelConfig::dinov2_small() + let config = Config::dinov2_small() .with_batch_size_all(xs.len()) .commit()?; let mut model = DINOv2::new(config)?; diff --git a/examples/doclayout-yolo/main.rs b/examples/doclayout-yolo/main.rs index 3b2c467..52eb13d 100644 --- a/examples/doclayout-yolo/main.rs +++ b/examples/doclayout-yolo/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLO, Annotator, DataLoader, ModelConfig}; +use usls::{models::YOLO, Annotator, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -18,7 +18,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = ModelConfig::doclayout_yolo_docstructbench() + let config = Config::doclayout_yolo_docstructbench() .with_model_device(args.device.as_str().try_into()?) .commit()?; let mut model = YOLO::new(config)?; diff --git a/examples/fast/main.rs b/examples/fast/main.rs index 493f737..875b466 100644 --- a/examples/fast/main.rs +++ b/examples/fast/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::DB, Annotator, DataLoader, ModelConfig, Scale, Style}; +use usls::{models::DB, Annotator, Config, DataLoader, Scale, Style}; #[derive(argh::FromArgs)] /// Example @@ -27,9 +27,9 @@ fn main() -> Result<()> { // build model let config = match args.scale.as_str().try_into()? { - Scale::T => ModelConfig::fast_tiny(), - Scale::S => ModelConfig::fast_small(), - Scale::B => ModelConfig::fast_base(), + Scale::T => Config::fast_tiny(), + Scale::S => Config::fast_small(), + Scale::B => Config::fast_base(), _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale), }; let mut model = DB::new( diff --git a/examples/fastsam/main.rs b/examples/fastsam/main.rs index eaacb74..f5aa616 100644 --- a/examples/fastsam/main.rs +++ b/examples/fastsam/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLO, Annotator, DataLoader, ModelConfig}; +use usls::{models::YOLO, Annotator, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -22,7 +22,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = ModelConfig::fastsam_s() + let config = Config::fastsam_s() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) .commit()?; diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs index 63bf33c..3ab421f 100644 --- a/examples/florence2/main.rs +++ b/examples/florence2/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::Florence2, Annotator, DataLoader, ModelConfig, Style, Task}; +use usls::{models::Florence2, Annotator, Config, DataLoader, Style, Task}; #[derive(argh::FromArgs)] /// Example @@ -25,7 +25,7 @@ fn main() -> Result<()> { let xs = DataLoader::try_read_n(&["images/green-car.jpg", "assets/bus.jpg"])?; // build model - let config = ModelConfig::florence2_base() + let config = Config::florence2_base() .with_dtype_all(args.dtype.as_str().try_into()?) .with_device_all(args.device.as_str().try_into()?) .with_batch_size_all(xs.len()) diff --git a/examples/grounding-dino/main.rs b/examples/grounding-dino/main.rs index e155428..1099963 100644 --- a/examples/grounding-dino/main.rs +++ b/examples/grounding-dino/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::GroundingDINO, Annotator, DataLoader, ModelConfig}; +use usls::{models::GroundingDINO, Annotator, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -45,7 +45,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); - let config = ModelConfig::grounding_dino_tiny() + let config = Config::grounding_dino_tiny() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) .with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::>()) diff --git a/examples/linknet/main.rs b/examples/linknet/main.rs index f7c85ad..b8a4523 100644 --- a/examples/linknet/main.rs +++ b/examples/linknet/main.rs @@ -1,6 +1,6 @@ use anyhow::Result; use usls::DataLoader; -use usls::{models::DB, Annotator, ModelConfig, Scale, Style}; +use usls::{models::DB, Annotator, Config, Scale, Style}; #[derive(argh::FromArgs)] /// Example @@ -28,9 +28,9 @@ fn main() -> Result<()> { // build model let config = match args.scale.as_str().try_into()? { - Scale::T => ModelConfig::linknet_r18(), - Scale::S => ModelConfig::linknet_r34(), - Scale::B => ModelConfig::linknet_r50(), + Scale::T => Config::linknet_r18(), + Scale::S => Config::linknet_r34(), + Scale::B => Config::linknet_r50(), _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale), }; let mut model = DB::new( diff --git a/examples/modnet/main.rs b/examples/modnet/main.rs index 9dde076..0c09036 100644 --- a/examples/modnet/main.rs +++ b/examples/modnet/main.rs @@ -1,4 +1,4 @@ -use usls::{models::MODNet, Annotator, DataLoader, ModelConfig}; +use usls::{models::MODNet, Annotator, Config, DataLoader}; fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() @@ -7,7 +7,7 @@ fn main() -> anyhow::Result<()> { .init(); // build model - let mut model = MODNet::new(ModelConfig::modnet_photographic().commit()?)?; + let mut model = MODNet::new(Config::modnet_photographic().commit()?)?; // load image let xs = DataLoader::try_read_n(&["images/liuyifei.png"])?; diff --git a/examples/moondream2/main.rs b/examples/moondream2/main.rs index 907a9aa..f23e9cf 100644 --- a/examples/moondream2/main.rs +++ b/examples/moondream2/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::Moondream2, Annotator, DataLoader, ModelConfig, Scale, Task}; +use usls::{models::Moondream2, Annotator, Config, DataLoader, Scale, Task}; #[derive(argh::FromArgs)] /// Example @@ -40,8 +40,8 @@ fn main() -> Result<()> { // build model let config = match args.scale.as_str().try_into()? { - Scale::Billion(0.5) => ModelConfig::moondream2_0_5b(), - Scale::Billion(2.) => ModelConfig::moondream2_2b(), + Scale::Billion(0.5) => Config::moondream2_0_5b(), + Scale::Billion(2.) => Config::moondream2_2b(), _ => unimplemented!(), } .with_dtype_all(args.dtype.as_str().try_into()?) diff --git a/examples/owlv2/main.rs b/examples/owlv2/main.rs index 3a0c6b4..0037d50 100644 --- a/examples/owlv2/main.rs +++ b/examples/owlv2/main.rs @@ -1,6 +1,6 @@ use anyhow::Result; use usls::DataLoader; -use usls::{models::OWLv2, Annotator, ModelConfig}; +use usls::{models::OWLv2, Annotator, Config}; #[derive(argh::FromArgs)] /// Example @@ -47,7 +47,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // config - let config = ModelConfig::owlv2_base_ensemble() + let config = Config::owlv2_base_ensemble() // owlv2_base() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) diff --git a/examples/picodet-layout/main.rs b/examples/picodet-layout/main.rs index 5034a54..cbae280 100644 --- a/examples/picodet-layout/main.rs +++ b/examples/picodet-layout/main.rs @@ -1,6 +1,6 @@ use anyhow::Result; use usls::DataLoader; -use usls::{models::PicoDet, Annotator, ModelConfig}; +use usls::{models::PicoDet, Annotator, Config}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -9,7 +9,7 @@ fn main() -> Result<()> { .init(); // config - let config = ModelConfig::picodet_layout_1x().commit()?; + let config = Config::picodet_layout_1x().commit()?; // picodet_l_layout_3cls() // picodet_l_layout_17cls() let mut model = PicoDet::new(config)?; diff --git a/examples/rfdetr/main.rs b/examples/rfdetr/main.rs index eb39ef8..79b544e 100644 --- a/examples/rfdetr/main.rs +++ b/examples/rfdetr/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RFDETR, Annotator, DataLoader, ModelConfig}; +use usls::{models::RFDETR, Annotator, Config, DataLoader}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,7 +8,7 @@ fn main() -> Result<()> { .init(); // config - let mut model = RFDETR::new(ModelConfig::rfdetr_base().commit()?)?; + let mut model = RFDETR::new(Config::rfdetr_base().commit()?)?; // load let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/rmbg/main.rs b/examples/rmbg/main.rs index 4fbdfbd..b1456d9 100644 --- a/examples/rmbg/main.rs +++ b/examples/rmbg/main.rs @@ -1,4 +1,4 @@ -use usls::{models::RMBG, Annotator, DataLoader, ModelConfig}; +use usls::{models::RMBG, Annotator, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -24,8 +24,8 @@ fn main() -> anyhow::Result<()> { let args: Args = argh::from_env(); let config = match args.ver { - 1.4 => ModelConfig::rmbg1_4(), - 2.0 => ModelConfig::rmbg2_0(), + 1.4 => Config::rmbg1_4(), + 2.0 => Config::rmbg2_0(), _ => unreachable!("Unsupported version"), }; diff --git a/examples/rtdetr/main.rs b/examples/rtdetr/main.rs index 8a19510..1f73704 100644 --- a/examples/rtdetr/main.rs +++ b/examples/rtdetr/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RTDETR, Annotator, DataLoader, ModelConfig}; +use usls::{models::RTDETR, Annotator, Config, DataLoader}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,7 +8,7 @@ fn main() -> Result<()> { .init(); // config - let config = ModelConfig::rtdetr_v2_s_coco().commit()?; + let config = Config::rtdetr_v2_s_coco().commit()?; // rtdetr_v1_r18vd_coco() // rtdetr_v2_ms_coco() // rtdetr_v2_m_coco() diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs index c2dde49..b68aef1 100644 --- a/examples/rtmo/main.rs +++ b/examples/rtmo/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RTMO, Annotator, DataLoader, ModelConfig, Style, SKELETON_COCO_19}; +use usls::{models::RTMO, Annotator, Config, DataLoader, Style, SKELETON_COCO_19}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,7 +8,7 @@ fn main() -> Result<()> { .init(); // build model - let mut model = RTMO::new(ModelConfig::rtmo_s().commit()?)?; + let mut model = RTMO::new(Config::rtmo_s().commit()?)?; // load image let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/sam/main.rs b/examples/sam/main.rs index 0da9d7f..9a698d8 100644 --- a/examples/sam/main.rs +++ b/examples/sam/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use usls::{ models::{SamKind, SamPrompt, SAM}, - Annotator, DataLoader, ModelConfig, Scale, + Annotator, Config, DataLoader, Scale, }; #[derive(argh::FromArgs)] @@ -29,16 +29,16 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // Build model let config = match args.kind.as_str().try_into()? { - SamKind::Sam => ModelConfig::sam_v1_base(), + SamKind::Sam => Config::sam_v1_base(), SamKind::Sam2 => match args.scale.as_str().try_into()? { - Scale::T => ModelConfig::sam2_tiny(), - Scale::S => ModelConfig::sam2_small(), - Scale::B => ModelConfig::sam2_base_plus(), + Scale::T => Config::sam2_tiny(), + Scale::S => Config::sam2_small(), + Scale::B => Config::sam2_base_plus(), _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale), }, - SamKind::MobileSam => ModelConfig::mobile_sam_tiny(), - SamKind::SamHq => ModelConfig::sam_hq_tiny(), - SamKind::EdgeSam => ModelConfig::edge_sam_3x(), + SamKind::MobileSam => Config::mobile_sam_tiny(), + SamKind::SamHq => Config::sam_hq_tiny(), + SamKind::EdgeSam => Config::edge_sam_3x(), } .with_device_all(args.device.as_str().try_into()?) .commit()?; diff --git a/examples/sam2/main.rs b/examples/sam2/main.rs index 79469f9..48eef6e 100644 --- a/examples/sam2/main.rs +++ b/examples/sam2/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use usls::{ models::{SamPrompt, SAM2}, - Annotator, DataLoader, ModelConfig, Scale, + Annotator, Config, DataLoader, Scale, }; #[derive(argh::FromArgs)] @@ -26,10 +26,10 @@ fn main() -> Result<()> { // Build model let config = match args.scale.as_str().try_into()? { - Scale::T => ModelConfig::sam2_1_tiny(), - Scale::S => ModelConfig::sam2_1_small(), - Scale::B => ModelConfig::sam2_1_base_plus(), - Scale::L => ModelConfig::sam2_1_large(), + Scale::T => Config::sam2_1_tiny(), + Scale::S => Config::sam2_1_small(), + Scale::B => Config::sam2_1_base_plus(), + Scale::L => Config::sam2_1_large(), _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale), } .with_device_all(args.device.as_str().try_into()?) diff --git a/examples/sapiens/main.rs b/examples/sapiens/main.rs index dbc2c23..caf2d17 100644 --- a/examples/sapiens/main.rs +++ b/examples/sapiens/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::Sapiens, Annotator, DataLoader, ModelConfig}; +use usls::{models::Sapiens, Annotator, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -17,7 +17,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build - let config = ModelConfig::sapiens_seg_0_3b() + let config = Config::sapiens_seg_0_3b() .with_model_device(args.device.as_str().try_into()?) .commit()?; let mut model = Sapiens::new(config)?; diff --git a/examples/slanet/main.rs b/examples/slanet/main.rs index fe66322..71a0fc7 100644 --- a/examples/slanet/main.rs +++ b/examples/slanet/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::SLANet, Annotator, Color, DataLoader, ModelConfig}; +use usls::{models::SLANet, Annotator, Color, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -26,7 +26,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = ModelConfig::slanet_lcnet_v2_mobile_ch() + let config = Config::slanet_lcnet_v2_mobile_ch() .with_model_device(args.device.as_str().try_into()?) .with_model_dtype(args.dtype.as_str().try_into()?) .commit()?; diff --git a/examples/smolvlm/main.rs b/examples/smolvlm/main.rs index 8063d4f..1f585cf 100644 --- a/examples/smolvlm/main.rs +++ b/examples/smolvlm/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::SmolVLM, DataLoader, ModelConfig, Scale}; +use usls::{models::SmolVLM, Config, DataLoader, Scale}; #[derive(argh::FromArgs)] /// Example @@ -30,8 +30,8 @@ fn main() -> Result<()> { // build model let config = match args.scale.as_str().try_into()? { - Scale::Million(256.) => ModelConfig::smolvlm_256m(), - Scale::Million(500.) => ModelConfig::smolvlm_500m(), + Scale::Million(256.) => Config::smolvlm_256m(), + Scale::Million(500.) => Config::smolvlm_500m(), _ => unimplemented!(), } .with_device_all(args.device.as_str().try_into()?) diff --git a/examples/svtr/main.rs b/examples/svtr/main.rs index f39c464..779afb7 100644 --- a/examples/svtr/main.rs +++ b/examples/svtr/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::SVTR, DataLoader, ModelConfig}; +use usls::{models::SVTR, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -22,7 +22,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = ModelConfig::ppocr_rec_v4_ch() + let config = Config::ppocr_rec_v4_ch() // ppocr_rec_v4_en() // repsvtr_ch() .with_model_device(args.device.as_str().try_into()?) diff --git a/examples/trocr/main.rs b/examples/trocr/main.rs index 79aad83..8e5fcc8 100644 --- a/examples/trocr/main.rs +++ b/examples/trocr/main.rs @@ -1,6 +1,6 @@ use usls::{ models::{TrOCR, TrOCRKind}, - DataLoader, ModelConfig, Scale, + Config, DataLoader, Scale, }; #[derive(argh::FromArgs)] @@ -40,12 +40,12 @@ fn main() -> anyhow::Result<()> { // build model let config = match args.scale.as_str().try_into()? { Scale::S => match args.kind.as_str().try_into()? { - TrOCRKind::Printed => ModelConfig::trocr_small_printed(), - TrOCRKind::HandWritten => ModelConfig::trocr_small_handwritten(), + TrOCRKind::Printed => Config::trocr_small_printed(), + TrOCRKind::HandWritten => Config::trocr_small_handwritten(), }, Scale::B => match args.kind.as_str().try_into()? { - TrOCRKind::Printed => ModelConfig::trocr_base_printed(), - TrOCRKind::HandWritten => ModelConfig::trocr_base_handwritten(), + TrOCRKind::Printed => Config::trocr_base_printed(), + TrOCRKind::HandWritten => Config::trocr_base_handwritten(), }, x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x), } diff --git a/examples/ultralytics-rtdetr/main.rs b/examples/ultralytics-rtdetr/main.rs index 6de6b8a..a67ebde 100644 --- a/examples/ultralytics-rtdetr/main.rs +++ b/examples/ultralytics-rtdetr/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLO, Annotator, DataLoader, ModelConfig}; +use usls::{models::YOLO, Annotator, Config, DataLoader}; #[derive(argh::FromArgs)] /// Example @@ -22,7 +22,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = ModelConfig::ultralytics_rtdetr_l() + let config = Config::ultralytics_rtdetr_l() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) .commit()?; diff --git a/examples/yolo-sam2/main.rs b/examples/yolo-sam2/main.rs index 3632be5..5e9fb53 100644 --- a/examples/yolo-sam2/main.rs +++ b/examples/yolo-sam2/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use usls::{ models::{SamPrompt, SAM2, YOLO}, - Annotator, DataLoader, ModelConfig, Scale, Style, + Annotator, Config, DataLoader, Scale, Style, }; #[derive(argh::FromArgs)] @@ -21,10 +21,10 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build SAM - let mut sam = SAM2::new(ModelConfig::sam2_1_tiny().commit()?)?; + let mut sam = SAM2::new(Config::sam2_1_tiny().commit()?)?; // build YOLOv8 - let options_yolo = ModelConfig::yolo_detect() + let options_yolo = Config::yolo_detect() .with_scale(Scale::N) .with_version(8.into()) .with_model_device(args.device.as_str().try_into()?) diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index 4c25ba1..e2f9007 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use usls::{ - models::YOLO, Annotator, DataLoader, ModelConfig, Style, NAMES_COCO_80, - NAMES_COCO_KEYPOINTS_17, NAMES_IMAGENET_1K, SKELETON_COCO_19, SKELETON_COLOR_COCO_19, + models::YOLO, Annotator, Config, DataLoader, Style, NAMES_COCO_80, NAMES_COCO_KEYPOINTS_17, + NAMES_IMAGENET_1K, SKELETON_COCO_19, SKELETON_COLOR_COCO_19, }; #[derive(argh::FromArgs, Debug)] @@ -130,7 +130,7 @@ fn main() -> Result<()> { .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .init(); let args: Args = argh::from_env(); - let mut config = ModelConfig::yolo() + let mut config = Config::yolo() .with_model_file(&args.model.unwrap_or_default()) .with_task(args.task.as_str().try_into()?) .with_version(args.ver.try_into()?) @@ -170,7 +170,8 @@ fn main() -> Result<()> { }) .with_topk(args.topk) .retain_classes(&args.retain_classes) - .exclude_classes(&args.exclude_classes); + .exclude_classes(&args.exclude_classes) + .with_model_num_dry_run(2); if args.use_coco_80_classes { config = config.with_class_names(&NAMES_COCO_80); } diff --git a/examples/yoloe/main.rs b/examples/yoloe/main.rs index 175cbb2..3fa3b2f 100644 --- a/examples/yoloe/main.rs +++ b/examples/yoloe/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLO, Annotator, DataLoader, ModelConfig, Style}; +use usls::{models::YOLO, Annotator, Config, DataLoader, Style}; #[derive(argh::FromArgs)] /// Example @@ -22,7 +22,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // config - let config = ModelConfig::yoloe_v8s_seg_pf() + let config = Config::yoloe_v8s_seg_pf() // yoloe_v8m_seg_pf() // yoloe_v8l_seg_pf() // yoloe_11s_seg_pf() diff --git a/examples/yolop/main.rs b/examples/yolop/main.rs index 767926b..78427fb 100644 --- a/examples/yolop/main.rs +++ b/examples/yolop/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLOPv2, Annotator, DataLoader, ModelConfig}; +use usls::{models::YOLOPv2, Annotator, Config, DataLoader}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,7 +8,7 @@ fn main() -> Result<()> { .init(); // build model - let mut model = YOLOPv2::new(ModelConfig::yolop_v2_480x800().commit()?)?; + let mut model = YOLOPv2::new(Config::yolop_v2_480x800().commit()?)?; // load image let xs = DataLoader::try_read_n(&["images/car-view.jpg"])?; diff --git a/src/inference/engine.rs b/src/inference/engine.rs index cd2a9f3..2884a7c 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -13,8 +13,8 @@ use prost::Message; use std::collections::HashSet; use crate::{ - build_progress_bar, elapsed, human_bytes_binary, onnx, DType, Device, EngineConfig, Iiix, - MinOptMax, Ops, Ts, Xs, PROGRESS_BAR_STYLE_CYAN_2, PROGRESS_BAR_STYLE_FINISH, X, + build_progress_bar, elapsed, human_bytes_binary, onnx, DType, Device, Iiix, MinOptMax, + ORTConfig, Ops, Ts, Xs, PROGRESS_BAR_STYLE_CYAN_2, PROGRESS_BAR_STYLE_FINISH, X, }; impl From for DType { @@ -93,7 +93,7 @@ impl Default for Engine { } impl Engine { - pub fn try_from_config(config: &EngineConfig) -> Result { + pub fn try_from_config(config: &ORTConfig) -> Result { Self { file: config.file.clone(), spec: config.spec.clone(), @@ -101,7 +101,7 @@ impl Engine { device: config.device, trt_fp16: config.trt_fp16, num_dry_run: config.num_dry_run, - graph_opt_level: config.ort_graph_opt_level, + graph_opt_level: config.graph_opt_level, ..Default::default() } .build() diff --git a/src/inference/hbb.rs b/src/inference/hbb.rs index 12bcac3..04e5115 100644 --- a/src/inference/hbb.rs +++ b/src/inference/hbb.rs @@ -17,7 +17,9 @@ impl std::fmt::Debug for Hbb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Hbb") .field("xyxy", &[self.x, self.y, self.xmax(), self.ymax()]) - .field("meta", &self.meta) + .field("id", &self.meta.id()) + .field("name", &self.meta.name()) + .field("confidence", &self.meta.confidence()) .finish() } } diff --git a/src/inference/keypoint.rs b/src/inference/keypoint.rs index 3bf25e4..8ba23c3 100644 --- a/src/inference/keypoint.rs +++ b/src/inference/keypoint.rs @@ -22,7 +22,6 @@ impl std::fmt::Debug for Keypoint { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Keypoint") .field("xy", &[self.x, self.y]) - .field("uid", &self.meta.uid()) .field("id", &self.meta.id()) .field("name", &self.meta.name()) .field("confidence", &self.meta.confidence()) diff --git a/src/inference/mask.rs b/src/inference/mask.rs index 09ce1bf..47bb1c0 100644 --- a/src/inference/mask.rs +++ b/src/inference/mask.rs @@ -20,7 +20,6 @@ impl std::fmt::Debug for Mask { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Mask") .field("dimensions", &self.dimensions()) - .field("uid", &self.meta.uid()) .field("id", &self.meta.id()) .field("name", &self.meta.name()) .field("confidence", &self.meta.confidence()) diff --git a/src/inference/mod.rs b/src/inference/mod.rs index 8e5783c..1cb35bd 100644 --- a/src/inference/mod.rs +++ b/src/inference/mod.rs @@ -1,12 +1,10 @@ #[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] mod engine; -mod engine_config; mod hbb; mod image; mod instance_meta; mod keypoint; mod mask; -mod model_config; mod obb; mod polygon; mod prob; @@ -22,13 +20,11 @@ pub(crate) mod onnx { #[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] pub use engine::*; -pub use engine_config::EngineConfig; pub use hbb::*; pub use image::*; pub use instance_meta::*; pub use keypoint::*; pub use mask::*; -pub use model_config::*; pub use obb::*; pub use polygon::*; pub use prob::*; diff --git a/src/inference/obb.rs b/src/inference/obb.rs index 0f2f781..1c034f4 100644 --- a/src/inference/obb.rs +++ b/src/inference/obb.rs @@ -13,7 +13,7 @@ pub struct Obb { impl std::fmt::Debug for Obb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Obb") - .field("uid", &self.meta.uid()) + .field("vertices", &self.vertices) .field("id", &self.meta.id()) .field("name", &self.meta.name()) .field("confidence", &self.meta.confidence()) diff --git a/src/inference/polygon.rs b/src/inference/polygon.rs index 1eecb96..3d97d7a 100644 --- a/src/inference/polygon.rs +++ b/src/inference/polygon.rs @@ -27,8 +27,7 @@ impl Default for Polygon { impl std::fmt::Debug for Polygon { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Polygon") - .field("count", &self.count()) - .field("uid", &self.meta.uid()) + .field("n_points", &self.count()) .field("id", &self.meta.id()) .field("name", &self.meta.name()) .field("confidence", &self.meta.confidence()) diff --git a/src/models/beit/config.rs b/src/models/beit/config.rs index 34eb389..b2668da 100644 --- a/src/models/beit/config.rs +++ b/src/models/beit/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `BEiT` -impl crate::ModelConfig { +impl crate::Config { pub fn beit() -> Self { Self::default() .with_name("beit") diff --git a/src/models/ben2/config.rs b/src/models/ben2/config.rs index f942025..9a40732 100644 --- a/src/models/ben2/config.rs +++ b/src/models/ben2/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `BEN2` -impl crate::ModelConfig { +impl crate::Config { pub fn ben2_base() -> Self { Self::rmbg().with_model_file("ben2-base.onnx") } diff --git a/src/models/blip/config.rs b/src/models/blip/config.rs index 0395d54..eb70c61 100644 --- a/src/models/blip/config.rs +++ b/src/models/blip/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `BLIP` -impl crate::ModelConfig { +impl crate::Config { #[allow(clippy::excessive_precision)] pub fn blip() -> Self { Self::default() diff --git a/src/models/blip/impl.rs b/src/models/blip/impl.rs index 3a4e922..933dc41 100644 --- a/src/models/blip/impl.rs +++ b/src/models/blip/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::{s, Axis}; -use crate::{elapsed, Engine, Image, LogitsSampler, ModelConfig, Processor, Ts, Xs, X, Y}; +use crate::{elapsed, Config, Engine, Image, LogitsSampler, Processor, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct Blip { @@ -18,7 +18,7 @@ pub struct Blip { } impl Blip { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let visual = Engine::try_from_config(&config.visual)?; let textual = Engine::try_from_config(&config.textual)?; let (batch, height, width) = ( diff --git a/src/models/clip/config.rs b/src/models/clip/config.rs index d0712b2..8b1d69f 100644 --- a/src/models/clip/config.rs +++ b/src/models/clip/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `CLIP` -impl crate::ModelConfig { +impl crate::Config { pub fn clip() -> Self { Self::default() .with_name("clip") diff --git a/src/models/clip/impl.rs b/src/models/clip/impl.rs index e3231b9..7d464c0 100644 --- a/src/models/clip/impl.rs +++ b/src/models/clip/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::Array2; -use crate::{elapsed, Engine, Image, ModelConfig, Processor, Ts, X}; +use crate::{elapsed, Config, Engine, Image, Processor, Ts, X}; #[derive(Debug, Builder)] pub struct Clip { @@ -16,7 +16,7 @@ pub struct Clip { } impl Clip { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let visual = Engine::try_from_config(&config.visual)?; let textual = Engine::try_from_config(&config.textual)?; let (batch, height, width) = ( diff --git a/src/models/convnext/config.rs b/src/models/convnext/config.rs index 5b14c14..162d409 100644 --- a/src/models/convnext/config.rs +++ b/src/models/convnext/config.rs @@ -1,7 +1,7 @@ use crate::NAMES_IMAGENET_1K; /// Model configuration for `ConvNeXt` -impl crate::ModelConfig { +impl crate::Config { pub fn convnext() -> Self { Self::default() .with_name("convnext") diff --git a/src/models/d_fine/config.rs b/src/models/d_fine/config.rs index 16de585..32e6eae 100644 --- a/src/models/d_fine/config.rs +++ b/src/models/d_fine/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `d_fine` -impl crate::ModelConfig { +impl crate::Config { pub fn d_fine() -> Self { Self::rtdetr().with_name("d-fine") } diff --git a/src/models/db/config.rs b/src/models/db/config.rs index 0493237..d728dde 100644 --- a/src/models/db/config.rs +++ b/src/models/db/config.rs @@ -1,5 +1,5 @@ /// Model configuration for [DB](https://github.com/MhLiao/DB) and [PaddleOCR-Det](https://github.com/PaddlePaddle/PaddleOCR) -impl crate::ModelConfig { +impl crate::Config { pub fn db() -> Self { Self::default() .with_name("db") diff --git a/src/models/db/impl.rs b/src/models/db/impl.rs index af7d874..781a2aa 100644 --- a/src/models/db/impl.rs +++ b/src/models/db/impl.rs @@ -4,8 +4,7 @@ use ndarray::Axis; use rayon::prelude::*; use crate::{ - elapsed, DynConf, Engine, Hbb, Image, Mask, ModelConfig, Obb, Ops, Polygon, Processor, Ts, Xs, - Y, + elapsed, Config, DynConf, Engine, Hbb, Image, Mask, Obb, Ops, Polygon, Processor, Ts, Xs, Y, }; #[derive(Debug, Builder)] @@ -25,7 +24,7 @@ pub struct DB { } impl DB { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts, spec) = ( engine.batch().opt(), diff --git a/src/models/deim/config.rs b/src/models/deim/config.rs index 81177ae..8dd6a38 100644 --- a/src/models/deim/config.rs +++ b/src/models/deim/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `DEIM` -impl crate::ModelConfig { +impl crate::Config { pub fn deim() -> Self { Self::d_fine().with_name("deim") } diff --git a/src/models/deit/config.rs b/src/models/deit/config.rs index 1f7c3b9..798c205 100644 --- a/src/models/deit/config.rs +++ b/src/models/deit/config.rs @@ -1,7 +1,7 @@ use crate::NAMES_IMAGENET_1K; /// Model configuration for `DeiT` -impl crate::ModelConfig { +impl crate::Config { pub fn deit() -> Self { Self::default() .with_name("deit") diff --git a/src/models/depth_anything/config.rs b/src/models/depth_anything/config.rs index e8adefd..30668fd 100644 --- a/src/models/depth_anything/config.rs +++ b/src/models/depth_anything/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `DepthAnything` -impl crate::ModelConfig { +impl crate::Config { pub fn depth_anything() -> Self { Self::default() .with_name("depth-anything") diff --git a/src/models/depth_anything/impl.rs b/src/models/depth_anything/impl.rs index 778fcc8..962f52f 100644 --- a/src/models/depth_anything/impl.rs +++ b/src/models/depth_anything/impl.rs @@ -1,7 +1,7 @@ use aksr::Builder; use anyhow::Result; -use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y}; +use crate::{elapsed, Config, Engine, Image, Mask, Ops, Processor, Ts, Xs, Y}; #[derive(Debug, Builder)] pub struct DepthAnything { @@ -15,7 +15,7 @@ pub struct DepthAnything { } impl DepthAnything { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); diff --git a/src/models/depth_pro/config.rs b/src/models/depth_pro/config.rs index e681f80..2d569c7 100644 --- a/src/models/depth_pro/config.rs +++ b/src/models/depth_pro/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `DepthPro` -impl crate::ModelConfig { +impl crate::Config { pub fn depth_pro() -> Self { Self::default() .with_name("depth-pro") diff --git a/src/models/depth_pro/impl.rs b/src/models/depth_pro/impl.rs index 6301437..cbb419e 100644 --- a/src/models/depth_pro/impl.rs +++ b/src/models/depth_pro/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::Axis; -use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y}; +use crate::{elapsed, Config, Engine, Image, Mask, Ops, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct DepthPro { @@ -16,7 +16,7 @@ pub struct DepthPro { } impl DepthPro { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( diff --git a/src/models/dinov2/config.rs b/src/models/dinov2/config.rs index 60df927..22eb541 100644 --- a/src/models/dinov2/config.rs +++ b/src/models/dinov2/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `DINOv2` -impl crate::ModelConfig { +impl crate::Config { pub fn dinov2() -> Self { Self::default() .with_name("dinov2") diff --git a/src/models/dinov2/impl.rs b/src/models/dinov2/impl.rs index c9a070c..53e3b74 100644 --- a/src/models/dinov2/impl.rs +++ b/src/models/dinov2/impl.rs @@ -1,7 +1,7 @@ use aksr::Builder; use anyhow::Result; -use crate::{elapsed, Engine, Image, ModelConfig, Processor, Scale, Ts, Xs, X}; +use crate::{elapsed, Config, Engine, Image, Processor, Scale, Ts, Xs, X}; #[derive(Builder, Debug)] pub struct DINOv2 { @@ -15,7 +15,7 @@ pub struct DINOv2 { } impl DINOv2 { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), diff --git a/src/models/fast/config.rs b/src/models/fast/config.rs index 02c4250..2277ff5 100644 --- a/src/models/fast/config.rs +++ b/src/models/fast/config.rs @@ -1,5 +1,5 @@ /// Model configuration for [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://github.com/czczup/FAST) -impl crate::ModelConfig { +impl crate::Config { pub fn fast() -> Self { Self::db() .with_name("fast") diff --git a/src/models/fastvit/config.rs b/src/models/fastvit/config.rs index 10038ef..351e790 100644 --- a/src/models/fastvit/config.rs +++ b/src/models/fastvit/config.rs @@ -1,7 +1,7 @@ use crate::NAMES_IMAGENET_1K; /// Model configuration for `FastViT` -impl crate::ModelConfig { +impl crate::Config { pub fn fastvit() -> Self { Self::default() .with_name("fastvit") diff --git a/src/models/florence2/config.rs b/src/models/florence2/config.rs index e1fac46..5da2c99 100644 --- a/src/models/florence2/config.rs +++ b/src/models/florence2/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `Florence2` -impl crate::ModelConfig { +impl crate::Config { pub fn florence2() -> Self { Self::default() .with_name("florence2") diff --git a/src/models/florence2/impl.rs b/src/models/florence2/impl.rs index 3491485..4404327 100644 --- a/src/models/florence2/impl.rs +++ b/src/models/florence2/impl.rs @@ -4,7 +4,7 @@ use ndarray::{s, Axis}; use rayon::prelude::*; use crate::{ - elapsed, models::Quantizer, Engine, Hbb, Image, LogitsSampler, ModelConfig, Polygon, Processor, + elapsed, models::Quantizer, Config, Engine, Hbb, Image, LogitsSampler, Polygon, Processor, Scale, Task, Ts, Xs, X, Y, }; @@ -28,7 +28,7 @@ pub struct Florence2 { } impl Florence2 { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let vision_encoder = Engine::try_from_config(&config.visual)?; let text_embed = Engine::try_from_config(&config.textual)?; let encoder = Engine::try_from_config(&config.textual_encoder)?; diff --git a/src/models/grounding_dino/config.rs b/src/models/grounding_dino/config.rs index 1ff2b34..375732c 100644 --- a/src/models/grounding_dino/config.rs +++ b/src/models/grounding_dino/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `GroundingDino` -impl crate::ModelConfig { +impl crate::Config { pub fn grounding_dino() -> Self { Self::default() .with_name("grounding-dino") diff --git a/src/models/grounding_dino/impl.rs b/src/models/grounding_dino/impl.rs index 8d6d62f..efebd19 100644 --- a/src/models/grounding_dino/impl.rs +++ b/src/models/grounding_dino/impl.rs @@ -4,7 +4,7 @@ use ndarray::{s, Array2, Axis}; use rayon::prelude::*; use std::fmt::Write; -use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y}; +use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Processor, Ts, Xs, X, Y}; #[derive(Builder, Debug)] pub struct GroundingDINO { @@ -24,7 +24,7 @@ pub struct GroundingDINO { } impl GroundingDINO { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( diff --git a/src/models/linknet/config.rs b/src/models/linknet/config.rs index 9c952b1..ecdc71e 100644 --- a/src/models/linknet/config.rs +++ b/src/models/linknet/config.rs @@ -1,5 +1,5 @@ /// Model configuration for [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/abs/1707.03718) -impl crate::ModelConfig { +impl crate::Config { pub fn linknet() -> Self { Self::fast() .with_name("linknet") diff --git a/src/models/mobileone/config.rs b/src/models/mobileone/config.rs index 1190d46..a1efa79 100644 --- a/src/models/mobileone/config.rs +++ b/src/models/mobileone/config.rs @@ -1,7 +1,7 @@ use crate::NAMES_IMAGENET_1K; /// Model configuration for `MobileOne` -impl crate::ModelConfig { +impl crate::Config { pub fn mobileone() -> Self { Self::default() .with_name("mobileone") diff --git a/src/models/modnet/config.rs b/src/models/modnet/config.rs index c72dd53..7d2c978 100644 --- a/src/models/modnet/config.rs +++ b/src/models/modnet/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `MODNet` -impl crate::ModelConfig { +impl crate::Config { pub fn modnet() -> Self { Self::default() .with_name("modnet") diff --git a/src/models/modnet/impl.rs b/src/models/modnet/impl.rs index 9dff8b3..0d123f7 100644 --- a/src/models/modnet/impl.rs +++ b/src/models/modnet/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::Axis; -use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y}; +use crate::{elapsed, Config, Engine, Image, Mask, Ops, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct MODNet { @@ -16,7 +16,7 @@ pub struct MODNet { } impl MODNet { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( diff --git a/src/models/moondream2/config.rs b/src/models/moondream2/config.rs index d5c7642..a68bfdc 100644 --- a/src/models/moondream2/config.rs +++ b/src/models/moondream2/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `moondream2` -impl crate::ModelConfig { +impl crate::Config { pub fn moondream2() -> Self { Self::default() .with_name("moondream2") diff --git a/src/models/moondream2/impl.rs b/src/models/moondream2/impl.rs index fd1100f..0c2aa7e 100644 --- a/src/models/moondream2/impl.rs +++ b/src/models/moondream2/impl.rs @@ -5,8 +5,8 @@ use ndarray::{s, Array, Array2, Array3, Axis, IxDyn}; use ndarray_npy::ReadNpyExt; use crate::{ - DType, Engine, Hbb, Hub, Image, Keypoint, LogitsSampler, ModelConfig, Processor, Scale, Task, - Xs, X, Y, + Config, DType, Engine, Hbb, Hub, Image, Keypoint, LogitsSampler, Processor, Scale, Task, Xs, X, + Y, }; #[derive(Builder, Debug)] @@ -32,7 +32,7 @@ pub struct Moondream2 { } impl Moondream2 { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let max_length = 2048; let max_objects = 50; let eos_token_id = 50256; diff --git a/src/models/owl/config.rs b/src/models/owl/config.rs index 2520c1b..9176a80 100644 --- a/src/models/owl/config.rs +++ b/src/models/owl/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `OWLv2` -impl crate::ModelConfig { +impl crate::Config { pub fn owlv2() -> Self { Self::default() .with_name("owlv2") diff --git a/src/models/owl/impl.rs b/src/models/owl/impl.rs index 1852da8..ad7a4e1 100644 --- a/src/models/owl/impl.rs +++ b/src/models/owl/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y}; +use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Processor, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct OWLv2 { @@ -22,7 +22,7 @@ pub struct OWLv2 { } impl OWLv2 { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), diff --git a/src/models/picodet/config.rs b/src/models/picodet/config.rs index 3ca083c..073f755 100644 --- a/src/models/picodet/config.rs +++ b/src/models/picodet/config.rs @@ -4,7 +4,7 @@ use crate::{ }; /// Model configuration for `PicoDet` -impl crate::ModelConfig { +impl crate::Config { pub fn picodet() -> Self { Self::default() .with_name("picodet") diff --git a/src/models/picodet/impl.rs b/src/models/picodet/impl.rs index 1700ae1..ce82629 100644 --- a/src/models/picodet/impl.rs +++ b/src/models/picodet/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::Axis; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y}; +use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Processor, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct PicoDet { @@ -19,7 +19,7 @@ pub struct PicoDet { } impl PicoDet { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), diff --git a/src/models/pipeline/basemodel.rs b/src/models/pipeline/basemodel.rs index 49430a0..724f481 100644 --- a/src/models/pipeline/basemodel.rs +++ b/src/models/pipeline/basemodel.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use crate::{ - elapsed, DType, Device, Engine, Image, ModelConfig, Processor, Scale, Task, Ts, Version, Xs, X, + elapsed, Config, DType, Device, Engine, Image, Processor, Scale, Task, Ts, Version, Xs, X, }; #[derive(Debug, Builder)] @@ -27,7 +27,7 @@ impl BaseModelVisual { self.ts.summary(); } - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let err_msg = "You need to specify the image height and image width for visual model."; let (batch, height, width, ts, spec) = ( @@ -103,7 +103,7 @@ impl BaseModelTextual { self.ts.summary(); } - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let (batch, ts, spec) = ( engine.batch().opt(), diff --git a/src/models/pipeline/image_classifier.rs b/src/models/pipeline/image_classifier.rs index dcbfccf..e851740 100644 --- a/src/models/pipeline/image_classifier.rs +++ b/src/models/pipeline/image_classifier.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::Axis; use rayon::prelude::*; -use crate::{elapsed, Engine, Image, ModelConfig, Prob, Processor, Ts, Xs, Y}; +use crate::{elapsed, Config, Engine, Image, Prob, Processor, Ts, Xs, Y}; #[derive(Debug, Builder)] pub struct ImageClassifier { @@ -19,16 +19,16 @@ pub struct ImageClassifier { ts: Ts, } -impl TryFrom for ImageClassifier { +impl TryFrom for ImageClassifier { type Error = anyhow::Error; - fn try_from(config: ModelConfig) -> Result { + fn try_from(config: Config) -> Result { Self::new(config) } } impl ImageClassifier { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( diff --git a/src/models/rfdetr/config.rs b/src/models/rfdetr/config.rs index 85f1220..780faab 100644 --- a/src/models/rfdetr/config.rs +++ b/src/models/rfdetr/config.rs @@ -1,7 +1,7 @@ use crate::NAMES_COCO_91; /// Model configuration for `RT-DETR` -impl crate::ModelConfig { +impl crate::Config { pub fn rfdetr() -> Self { Self::default() .with_name("rfdetr") diff --git a/src/models/rfdetr/impl.rs b/src/models/rfdetr/impl.rs index 5293266..e5c2b5d 100644 --- a/src/models/rfdetr/impl.rs +++ b/src/models/rfdetr/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, Y}; +use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Processor, Ts, Xs, Y}; #[derive(Debug, Builder)] pub struct RFDETR { @@ -19,7 +19,7 @@ pub struct RFDETR { } impl RFDETR { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), diff --git a/src/models/rmbg/config.rs b/src/models/rmbg/config.rs index 2bb963d..c0f9caa 100644 --- a/src/models/rmbg/config.rs +++ b/src/models/rmbg/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `RMBG` -impl crate::ModelConfig { +impl crate::Config { pub fn rmbg() -> Self { Self::default() .with_name("rmbg") diff --git a/src/models/rmbg/impl.rs b/src/models/rmbg/impl.rs index ea4d9c6..34032e2 100644 --- a/src/models/rmbg/impl.rs +++ b/src/models/rmbg/impl.rs @@ -1,7 +1,7 @@ use aksr::Builder; use anyhow::Result; -use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y}; +use crate::{elapsed, Config, Engine, Image, Mask, Ops, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct RMBG { @@ -15,7 +15,7 @@ pub struct RMBG { } impl RMBG { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( diff --git a/src/models/rtdetr/config.rs b/src/models/rtdetr/config.rs index 56f8da4..ab87f23 100644 --- a/src/models/rtdetr/config.rs +++ b/src/models/rtdetr/config.rs @@ -1,7 +1,7 @@ use crate::NAMES_COCO_80; /// Model configuration for `RT-DETR` -impl crate::ModelConfig { +impl crate::Config { pub fn rtdetr() -> Self { Self::default() .with_name("rtdetr") diff --git a/src/models/rtdetr/impl.rs b/src/models/rtdetr/impl.rs index 2b0675c..a75b09b 100644 --- a/src/models/rtdetr/impl.rs +++ b/src/models/rtdetr/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y}; +use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Processor, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct RTDETR { @@ -19,7 +19,7 @@ pub struct RTDETR { } impl RTDETR { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), diff --git a/src/models/rtmo/config.rs b/src/models/rtmo/config.rs index 43bcb24..03f1d44 100644 --- a/src/models/rtmo/config.rs +++ b/src/models/rtmo/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `RTMO` -impl crate::ModelConfig { +impl crate::Config { pub fn rtmo() -> Self { Self::default() .with_name("rtmo") diff --git a/src/models/rtmo/impl.rs b/src/models/rtmo/impl.rs index de5227a..14cf442 100644 --- a/src/models/rtmo/impl.rs +++ b/src/models/rtmo/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::Axis; -use crate::{elapsed, DynConf, Engine, Hbb, Image, Keypoint, ModelConfig, Processor, Ts, Xs, Y}; +use crate::{elapsed, Config, DynConf, Engine, Hbb, Image, Keypoint, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct RTMO { @@ -18,7 +18,7 @@ pub struct RTMO { } impl RTMO { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( diff --git a/src/models/sam/config.rs b/src/models/sam/config.rs index f46e86c..bd03825 100644 --- a/src/models/sam/config.rs +++ b/src/models/sam/config.rs @@ -1,7 +1,7 @@ -use crate::{models::SamKind, ModelConfig}; +use crate::{models::SamKind, Config}; /// Model configuration for `Segment Anything Model` -impl ModelConfig { +impl Config { pub fn sam() -> Self { Self::default() .with_name("sam") diff --git a/src/models/sam/impl.rs b/src/models/sam/impl.rs index fea747d..d022295 100644 --- a/src/models/sam/impl.rs +++ b/src/models/sam/impl.rs @@ -4,8 +4,7 @@ use ndarray::{s, Axis}; use rand::prelude::*; use crate::{ - elapsed, DynConf, Engine, Image, Mask, ModelConfig, Ops, Polygon, Processor, SamPrompt, Ts, Xs, - X, Y, + elapsed, Config, DynConf, Engine, Image, Mask, Ops, Polygon, Processor, SamPrompt, Ts, Xs, X, Y, }; #[derive(Debug, Clone)] @@ -49,7 +48,7 @@ pub struct SAM { } impl SAM { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let encoder = Engine::try_from_config(&config.encoder)?; let decoder = Engine::try_from_config(&config.decoder)?; diff --git a/src/models/sam2/config.rs b/src/models/sam2/config.rs index f58f7a7..1ca09c8 100644 --- a/src/models/sam2/config.rs +++ b/src/models/sam2/config.rs @@ -1,7 +1,7 @@ -use crate::ModelConfig; +use crate::Config; /// Model configuration for `SAM2.1` -impl ModelConfig { +impl Config { pub fn sam2_1_tiny() -> Self { Self::sam() .with_encoder_file("sam2.1-hiera-tiny-encoder.onnx") diff --git a/src/models/sam2/impl.rs b/src/models/sam2/impl.rs index 21c9412..822ad06 100644 --- a/src/models/sam2/impl.rs +++ b/src/models/sam2/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use crate::{ - elapsed, DynConf, Engine, Image, Mask, ModelConfig, Ops, Processor, SamPrompt, Ts, Xs, X, Y, + elapsed, Config, DynConf, Engine, Image, Mask, Ops, Processor, SamPrompt, Ts, Xs, X, Y, }; #[derive(Builder, Debug)] @@ -20,7 +20,7 @@ pub struct SAM2 { } impl SAM2 { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let encoder = Engine::try_from_config(&config.encoder)?; let decoder = Engine::try_from_config(&config.decoder)?; let (batch, height, width) = ( diff --git a/src/models/sapiens/config.rs b/src/models/sapiens/config.rs index a51fbc7..55547ff 100644 --- a/src/models/sapiens/config.rs +++ b/src/models/sapiens/config.rs @@ -1,7 +1,7 @@ use crate::NAMES_BODY_PARTS_28; /// Model configuration for `Sapiens` -impl crate::ModelConfig { +impl crate::Config { pub fn sapiens() -> Self { Self::default() .with_name("sapiens") diff --git a/src/models/sapiens/impl.rs b/src/models/sapiens/impl.rs index a50cd58..577e756 100644 --- a/src/models/sapiens/impl.rs +++ b/src/models/sapiens/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::{s, Array2, Axis}; -use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Polygon, Processor, Task, Ts, Xs, Y}; +use crate::{elapsed, Config, Engine, Image, Mask, Ops, Polygon, Processor, Task, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct Sapiens { @@ -18,7 +18,7 @@ pub struct Sapiens { } impl Sapiens { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( diff --git a/src/models/slanet/config.rs b/src/models/slanet/config.rs index d045fdc..ded9884 100644 --- a/src/models/slanet/config.rs +++ b/src/models/slanet/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `SLANet` -impl crate::ModelConfig { +impl crate::Config { pub fn slanet() -> Self { Self::default() .with_name("slanet") diff --git a/src/models/slanet/impl.rs b/src/models/slanet/impl.rs index 54da596..5df6b01 100644 --- a/src/models/slanet/impl.rs +++ b/src/models/slanet/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::{s, Axis}; -use crate::{elapsed, models::BaseModelVisual, Image, Keypoint, ModelConfig, Ts, Xs, Y}; +use crate::{elapsed, models::BaseModelVisual, Config, Image, Keypoint, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct SLANet { @@ -19,7 +19,7 @@ impl SLANet { self.ts.summary(); } - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let base = BaseModelVisual::new(config)?; let spec = base.engine().spec().to_owned(); let sos = 0; diff --git a/src/models/smolvlm/config.rs b/src/models/smolvlm/config.rs index 6a41fe9..dcbdc9e 100644 --- a/src/models/smolvlm/config.rs +++ b/src/models/smolvlm/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `SmolVLM` -impl crate::ModelConfig { +impl crate::Config { pub fn smolvlm() -> Self { Self::default() .with_name("smolvlm") diff --git a/src/models/smolvlm/impl.rs b/src/models/smolvlm/impl.rs index 3039a0b..c5df584 100644 --- a/src/models/smolvlm/impl.rs +++ b/src/models/smolvlm/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use image::GenericImageView; use ndarray::s; -use crate::{Engine, Image, LogitsSampler, ModelConfig, Processor, Scale, Ts, Xs, X, Y}; +use crate::{Config, Engine, Image, LogitsSampler, Processor, Scale, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct SmolVLM { @@ -32,7 +32,7 @@ pub struct SmolVLM { } impl SmolVLM { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let vision = Engine::try_from_config(&config.visual)?; let text_embed = Engine::try_from_config(&config.textual)?; let decoder = Engine::try_from_config(&config.textual_decoder_merged)?; diff --git a/src/models/svtr/config.rs b/src/models/svtr/config.rs index 583b7f5..454c8f1 100644 --- a/src/models/svtr/config.rs +++ b/src/models/svtr/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `SVTR` -impl crate::ModelConfig { +impl crate::Config { pub fn svtr() -> Self { Self::default() .with_name("svtr") diff --git a/src/models/svtr/impl.rs b/src/models/svtr/impl.rs index 728c7bb..3ae1ccf 100644 --- a/src/models/svtr/impl.rs +++ b/src/models/svtr/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::Axis; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Image, ModelConfig, Processor, Ts, Xs, Y}; +use crate::{elapsed, Config, DynConf, Engine, Image, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct SVTR { @@ -18,7 +18,7 @@ pub struct SVTR { } impl SVTR { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), diff --git a/src/models/trocr/config.rs b/src/models/trocr/config.rs index 291c8e6..3b6f6c9 100644 --- a/src/models/trocr/config.rs +++ b/src/models/trocr/config.rs @@ -1,7 +1,7 @@ use crate::Scale; /// Model configuration for `TrOCR` -impl crate::ModelConfig { +impl crate::Config { pub fn trocr() -> Self { Self::default() .with_name("trocr") diff --git a/src/models/trocr/impl.rs b/src/models/trocr/impl.rs index 4a27fef..8811a07 100644 --- a/src/models/trocr/impl.rs +++ b/src/models/trocr/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use rayon::prelude::*; -use crate::{elapsed, Engine, Image, LogitsSampler, ModelConfig, Processor, Scale, Ts, Xs, X, Y}; +use crate::{elapsed, Config, Engine, Image, LogitsSampler, Processor, Scale, Ts, Xs, X, Y}; #[derive(Debug, Copy, Clone)] pub enum TrOCRKind { @@ -40,7 +40,7 @@ pub struct TrOCR { } impl TrOCR { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let encoder = Engine::try_from_config(&config.visual)?; let decoder = Engine::try_from_config(&config.textual_decoder)?; let decoder_merged = Engine::try_from_config(&config.textual_decoder_merged)?; diff --git a/src/models/yolo/config.rs b/src/models/yolo/config.rs index 6fe6078..1f240e9 100644 --- a/src/models/yolo/config.rs +++ b/src/models/yolo/config.rs @@ -1,10 +1,10 @@ use crate::{ - models::YOLOPredsFormat, ModelConfig, ResizeMode, Scale, Task, NAMES_COCO_80, + models::YOLOPredsFormat, Config, ResizeMode, Scale, Task, NAMES_COCO_80, NAMES_COCO_KEYPOINTS_17, NAMES_DOTA_V1_15, NAMES_IMAGENET_1K, NAMES_YOLOE_4585, NAMES_YOLO_DOCLAYOUT_10, }; -impl ModelConfig { +impl Config { pub fn yolo() -> Self { Self::default() .with_name("yolo") diff --git a/src/models/yolo/impl.rs b/src/models/yolo/impl.rs index ac05d38..5f56d11 100644 --- a/src/models/yolo/impl.rs +++ b/src/models/yolo/impl.rs @@ -8,8 +8,8 @@ use regex::Regex; use crate::{ elapsed, models::{BoxType, YOLOPredsFormat}, - DynConf, Engine, Hbb, Image, Keypoint, Mask, ModelConfig, NmsOps, Obb, Ops, Prob, Processor, - Task, Ts, Version, Xs, Y, + Config, DynConf, Engine, Hbb, Image, Keypoint, Mask, NmsOps, Obb, Ops, Prob, Processor, Task, + Ts, Version, Xs, Y, }; #[derive(Debug, Builder)] @@ -36,16 +36,16 @@ pub struct YOLO { classes_retained: Vec, } -impl TryFrom for YOLO { +impl TryFrom for YOLO { type Error = anyhow::Error; - fn try_from(config: ModelConfig) -> Result { + fn try_from(config: Config) -> Result { Self::new(config) } } impl YOLO { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts, spec) = ( engine.batch().opt(), @@ -204,7 +204,7 @@ impl YOLO { if nc == 0 && names.is_empty() { anyhow::bail!( "Neither class names nor the number of classes were specified. \ - \nConsider specify them with `ModelConfig::default().with_nc()` or `ModelConfig::default().with_class_names()`" + \nConsider specify them with `Config::default().with_nc()` or `Config::default().with_class_names()`" ); } @@ -226,7 +226,7 @@ impl YOLO { (true, Some(nk)) => nk, (true, None) => anyhow::bail!( "Neither keypoint names nor the number of keypoints were specified when doing `KeypointsDetection` task. \ - \nConsider specify them with `ModelConfig::default().with_nk()` or `ModelConfig::default().with_keypoint_names()`" + \nConsider specify them with `Config::default().with_nk()` or `Config::default().with_keypoint_names()`" ), } } else { diff --git a/src/models/yolop/config.rs b/src/models/yolop/config.rs index 9736c7d..208d9dc 100644 --- a/src/models/yolop/config.rs +++ b/src/models/yolop/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `YOLOP` -impl crate::ModelConfig { +impl crate::Config { pub fn yolop() -> Self { Self::default() .with_name("yolop") diff --git a/src/models/yolop/impl.rs b/src/models/yolop/impl.rs index 717078f..6324516 100644 --- a/src/models/yolop/impl.rs +++ b/src/models/yolop/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Array, Axis, IxDyn}; use crate::{ - elapsed, DynConf, Engine, Hbb, Image, ModelConfig, NmsOps, Ops, Polygon, Processor, Ts, Xs, Y, + elapsed, Config, DynConf, Engine, Hbb, Image, NmsOps, Ops, Polygon, Processor, Ts, Xs, Y, }; #[derive(Builder, Debug)] @@ -20,7 +20,7 @@ pub struct YOLOPv2 { } impl YOLOPv2 { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: Config) -> Result { let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( diff --git a/src/inference/model_config.rs b/src/utils/config.rs similarity index 76% rename from src/inference/model_config.rs rename to src/utils/config.rs index 0944b53..3f9895f 100644 --- a/src/inference/model_config.rs +++ b/src/utils/config.rs @@ -1,14 +1,14 @@ use aksr::Builder; use crate::{ - impl_model_config_methods, impl_process_config_methods, + impl_ort_config_methods, impl_processor_config_methods, models::{SamKind, YOLOPredsFormat}, - EngineConfig, ProcessorConfig, Scale, Task, Version, + ORTConfig, ProcessorConfig, Scale, Task, Version, }; -/// ModelConfig for building models and inference +/// Config for building models and inference #[derive(Builder, Debug, Clone)] -pub struct ModelConfig { +pub struct Config { // Basics pub name: &'static str, pub version: Option, @@ -16,22 +16,22 @@ pub struct ModelConfig { pub scale: Option, // Engines - pub model: EngineConfig, - pub visual: EngineConfig, - pub textual: EngineConfig, - pub encoder: EngineConfig, - pub decoder: EngineConfig, - pub visual_encoder: EngineConfig, - pub textual_encoder: EngineConfig, - pub visual_decoder: EngineConfig, - pub textual_decoder: EngineConfig, - pub textual_decoder_merged: EngineConfig, - pub size_encoder: EngineConfig, - pub size_decoder: EngineConfig, - pub coord_encoder: EngineConfig, - pub coord_decoder: EngineConfig, - pub visual_projection: EngineConfig, - pub textual_projection: EngineConfig, + pub model: ORTConfig, + pub visual: ORTConfig, + pub textual: ORTConfig, + pub encoder: ORTConfig, + pub decoder: ORTConfig, + pub visual_encoder: ORTConfig, + pub textual_encoder: ORTConfig, + pub visual_decoder: ORTConfig, + pub textual_decoder: ORTConfig, + pub textual_decoder_merged: ORTConfig, + pub size_encoder: ORTConfig, + pub size_decoder: ORTConfig, + pub coord_encoder: ORTConfig, + pub coord_decoder: ORTConfig, + pub visual_projection: ORTConfig, + pub textual_projection: ORTConfig, // Processor pub processor: ProcessorConfig, @@ -65,7 +65,7 @@ pub struct ModelConfig { pub sam_low_res_mask: Option, } -impl Default for ModelConfig { +impl Default for Config { fn default() -> Self { Self { class_names: vec![], @@ -116,7 +116,7 @@ impl Default for ModelConfig { } } -impl ModelConfig { +impl Config { pub fn exclude_classes(mut self, xs: &[usize]) -> Self { self.classes_retained.clear(); self.classes_excluded.extend_from_slice(xs); @@ -147,7 +147,7 @@ impl ModelConfig { self.model.file = y; } - fn try_commit(name: &str, mut m: EngineConfig) -> anyhow::Result { + fn try_commit(name: &str, mut m: ORTConfig) -> anyhow::Result { if !m.file.is_empty() { m = m.try_commit(name)?; return Ok(m); @@ -176,6 +176,27 @@ impl ModelConfig { Ok(self) } + pub fn with_num_dry_run_all(mut self, x: usize) -> Self { + self.visual = self.visual.with_num_dry_run(x); + self.textual = self.textual.with_num_dry_run(x); + self.model = self.model.with_num_dry_run(x); + self.encoder = self.encoder.with_num_dry_run(x); + self.decoder = self.decoder.with_num_dry_run(x); + self.visual_encoder = self.visual_encoder.with_num_dry_run(x); + self.textual_encoder = self.textual_encoder.with_num_dry_run(x); + self.visual_decoder = self.visual_decoder.with_num_dry_run(x); + self.textual_decoder = self.textual_decoder.with_num_dry_run(x); + self.textual_decoder_merged = self.textual_decoder_merged.with_num_dry_run(x); + self.size_encoder = self.size_encoder.with_num_dry_run(x); + self.size_decoder = self.size_decoder.with_num_dry_run(x); + self.coord_encoder = self.coord_encoder.with_num_dry_run(x); + self.coord_decoder = self.coord_decoder.with_num_dry_run(x); + self.visual_projection = self.visual_projection.with_num_dry_run(x); + self.textual_projection = self.textual_projection.with_num_dry_run(x); + + self + } + pub fn with_batch_size_all(mut self, batch_size: usize) -> Self { self.visual = self.visual.with_ixx(0, 0, batch_size.into()); self.textual = self.textual.with_ixx(0, 0, batch_size.into()); @@ -242,20 +263,20 @@ impl ModelConfig { } } -impl_model_config_methods!(ModelConfig, model); -impl_model_config_methods!(ModelConfig, visual); -impl_model_config_methods!(ModelConfig, textual); -impl_model_config_methods!(ModelConfig, encoder); -impl_model_config_methods!(ModelConfig, decoder); -impl_model_config_methods!(ModelConfig, visual_encoder); -impl_model_config_methods!(ModelConfig, textual_encoder); -impl_model_config_methods!(ModelConfig, visual_decoder); -impl_model_config_methods!(ModelConfig, textual_decoder); -impl_model_config_methods!(ModelConfig, textual_decoder_merged); -impl_model_config_methods!(ModelConfig, size_encoder); -impl_model_config_methods!(ModelConfig, size_decoder); -impl_model_config_methods!(ModelConfig, coord_encoder); -impl_model_config_methods!(ModelConfig, coord_decoder); -impl_model_config_methods!(ModelConfig, visual_projection); -impl_model_config_methods!(ModelConfig, textual_projection); -impl_process_config_methods!(ModelConfig, processor); +impl_ort_config_methods!(Config, model); +impl_ort_config_methods!(Config, visual); +impl_ort_config_methods!(Config, textual); +impl_ort_config_methods!(Config, encoder); +impl_ort_config_methods!(Config, decoder); +impl_ort_config_methods!(Config, visual_encoder); +impl_ort_config_methods!(Config, textual_encoder); +impl_ort_config_methods!(Config, visual_decoder); +impl_ort_config_methods!(Config, textual_decoder); +impl_ort_config_methods!(Config, textual_decoder_merged); +impl_ort_config_methods!(Config, size_encoder); +impl_ort_config_methods!(Config, size_decoder); +impl_ort_config_methods!(Config, coord_encoder); +impl_ort_config_methods!(Config, coord_decoder); +impl_ort_config_methods!(Config, visual_projection); +impl_ort_config_methods!(Config, textual_projection); +impl_processor_config_methods!(Config, processor); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 008dfb6..e145054 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,3 +1,4 @@ +mod config; mod device; mod dtype; mod dynconf; @@ -6,6 +7,7 @@ mod logits_sampler; mod min_opt_max; mod names; mod ops; +mod ort_config; mod processor; mod processor_config; mod retry; @@ -15,6 +17,7 @@ mod traits; mod ts; mod version; +pub use config::*; pub use device::Device; pub use dtype::DType; pub use dynconf::DynConf; @@ -23,6 +26,7 @@ pub use logits_sampler::LogitsSampler; pub use min_opt_max::MinOptMax; pub use names::*; pub use ops::*; +pub use ort_config::ORTConfig; pub use processor::*; pub use processor_config::ProcessorConfig; pub use scale::Scale; diff --git a/src/utils/ops.rs b/src/utils/ops.rs index 9f54162..b29e35f 100644 --- a/src/utils/ops.rs +++ b/src/utils/ops.rs @@ -403,7 +403,7 @@ impl Ops<'_> { } "height" => (th * w0 / h0, th), "width" => (tw, tw * h0 / w0), - _ => anyhow::bail!("EngineConfig for `letterbox`: width, height, auto"), + _ => anyhow::bail!("ORTConfig for `letterbox`: width, height, auto"), }; let mut dst = Image::from_vec_u8( diff --git a/src/inference/engine_config.rs b/src/utils/ort_config.rs similarity index 87% rename from src/inference/engine_config.rs rename to src/utils/ort_config.rs index c63f251..d1c0715 100644 --- a/src/inference/engine_config.rs +++ b/src/utils/ort_config.rs @@ -3,19 +3,34 @@ use anyhow::Result; use crate::{try_fetch_file_stem, DType, Device, Hub, Iiix, MinOptMax}; -#[derive(Builder, Debug, Clone, Default)] -pub struct EngineConfig { +#[derive(Builder, Debug, Clone)] +pub struct ORTConfig { pub file: String, pub device: Device, pub iiixs: Vec, pub num_dry_run: usize, pub trt_fp16: bool, - pub ort_graph_opt_level: Option, + pub graph_opt_level: Option, pub spec: String, // TODO: move out pub dtype: DType, // For dynamically loading the model } -impl EngineConfig { +impl Default for ORTConfig { + fn default() -> Self { + Self { + file: Default::default(), + device: Default::default(), + iiixs: Default::default(), + graph_opt_level: Default::default(), + spec: Default::default(), + dtype: Default::default(), + num_dry_run: 3, + trt_fp16: true, + } + } +} + +impl ORTConfig { pub fn try_commit(mut self, name: &str) -> Result { // Identify the local model or fetch the remote model if std::path::PathBuf::from(&self.file).exists() { @@ -65,7 +80,7 @@ impl EngineConfig { } } -impl EngineConfig { +impl ORTConfig { pub fn with_ixx(mut self, i: usize, ii: usize, x: MinOptMax) -> Self { self.iiixs.push(Iiix::from((i, ii, x))); self @@ -78,7 +93,7 @@ impl EngineConfig { } #[macro_export] -macro_rules! impl_model_config_methods { +macro_rules! impl_ort_config_methods { ($ty:ty, $field:ident) => { impl $ty { paste::paste! { diff --git a/src/utils/processor_config.rs b/src/utils/processor_config.rs index d6a3daf..6a5898e 100644 --- a/src/utils/processor_config.rs +++ b/src/utils/processor_config.rs @@ -151,7 +151,7 @@ impl ProcessorConfig { } #[macro_export] -macro_rules! impl_process_config_methods { +macro_rules! impl_processor_config_methods { ($ty:ty, $field:ident) => { impl $ty { pub fn with_image_width(mut self, image_width: u32) -> Self {