From 70e6b2f03e8a1399f8c504cf811c91de88451e79 Mon Sep 17 00:00:00 2001 From: jamjamjon Date: Fri, 16 May 2025 15:45:27 +0800 Subject: [PATCH] Options -> ModelConfig --- Cargo.toml | 3 +- README.md | 3 +- examples/ben2/main.rs | 6 +- examples/blip/main.rs | 11 +- examples/classifier/main.rs | 18 +- examples/clip/main.rs | 15 +- examples/d-fine/main.rs | 7 +- examples/db/main.rs | 18 +- examples/deim/main.rs | 7 +- examples/depth-anything/main.rs | 5 +- examples/depth-pro/main.rs | 6 +- examples/dinov2/main.rs | 8 +- examples/doclayout-yolo/main.rs | 4 +- examples/fast/main.rs | 16 +- examples/fastsam/main.rs | 4 +- examples/florence2/README.md | 2 +- examples/florence2/main.rs | 59 +-- examples/grounding-dino/main.rs | 8 +- examples/linknet/main.rs | 12 +- examples/modnet/main.rs | 5 +- examples/moondream2/main.rs | 90 +---- examples/owlv2/main.rs | 8 +- examples/picodet-layout/main.rs | 13 +- examples/rfdetr/main.rs | 7 +- examples/rmbg/main.rs | 14 +- examples/rtdetr/main.rs | 19 +- examples/rtmo/main.rs | 4 +- examples/sam/main.rs | 44 +-- examples/sam2/main.rs | 37 +- examples/sapiens/main.rs | 6 +- examples/slanet/main.rs | 6 +- examples/smolvlm/main.rs | 35 +- examples/svtr/main.rs | 6 +- examples/trocr/main.rs | 62 +-- examples/yolo-sam2/main.rs | 15 +- examples/yolo/README.md | 2 +- examples/yolo/main.rs | 31 +- examples/yoloe/README.md | 2 +- examples/yoloe/main.rs | 8 +- examples/yolop/main.rs | 5 +- examples/yolov8-rtdetr/main.rs | 4 +- src/inference/engine.rs | 18 +- src/inference/engine_config.rs | 112 ++++++ src/inference/image.rs | 6 +- src/inference/mod.rs | 4 + src/inference/model_config.rs | 243 ++++++++++++ src/io/dataloader.rs | 10 +- src/models/beit/config.rs | 8 +- src/models/ben2/config.rs | 2 +- src/models/blip/config.rs | 42 +- src/models/blip/impl.rs | 56 +-- src/models/clip/config.rs | 90 ++--- src/models/clip/impl.rs | 142 ++----- src/models/convnext/config.rs | 4 +- src/models/d_fine/config.rs | 4 +- src/models/db/config.rs | 6 +- 
src/models/db/impl.rs | 20 +- src/models/deim/config.rs | 4 +- src/models/deit/config.rs | 4 +- src/models/depth_anything/config.rs | 14 +- src/models/depth_anything/impl.rs | 10 +- src/models/depth_pro/config.rs | 16 +- src/models/depth_pro/impl.rs | 9 +- src/models/dinov2/config.rs | 8 +- src/models/dinov2/impl.rs | 11 +- src/models/fast/config.rs | 4 +- src/models/fastvit/config.rs | 4 +- src/models/florence2/config.rs | 68 +--- src/models/florence2/impl.rs | 89 +++-- src/models/grounding_dino/config.rs | 12 +- src/models/grounding_dino/impl.rs | 23 +- src/models/linknet/config.rs | 4 +- src/models/mobileone/config.rs | 4 +- src/models/modnet/config.rs | 4 +- src/models/modnet/impl.rs | 9 +- src/models/moondream2/config.rs | 144 ++----- src/models/moondream2/impl.rs | 212 +++------- src/models/owl/config.rs | 8 +- src/models/owl/impl.rs | 17 +- src/models/picodet/config.rs | 6 +- src/models/picodet/impl.rs | 17 +- src/models/pipeline/basemodel.rs | 46 +-- src/models/pipeline/image_classifier.rs | 23 +- src/models/rfdetr/config.rs | 9 +- src/models/rfdetr/impl.rs | 22 +- src/models/rmbg/config.rs | 5 +- src/models/rmbg/impl.rs | 10 +- src/models/rtdetr/config.rs | 8 +- src/models/rtdetr/impl.rs | 23 +- src/models/rtmo/config.rs | 5 +- src/models/rtmo/impl.rs | 17 +- src/models/sam/config.rs | 107 ++---- src/models/sam/impl.rs | 28 +- src/models/sam2/config.rs | 54 +-- src/models/sam2/impl.rs | 14 +- src/models/sapiens/config.rs | 27 +- src/models/sapiens/impl.rs | 14 +- src/models/slanet/config.rs | 8 +- src/models/slanet/impl.rs | 6 +- src/models/smolvlm/config.rs | 62 +-- src/models/smolvlm/impl.rs | 108 ++---- src/models/svtr/config.rs | 8 +- src/models/svtr/impl.rs | 13 +- src/models/trocr/config.rs | 108 ++---- src/models/trocr/impl.rs | 241 +++--------- src/models/yolo/config.rs | 242 ++++-------- src/models/yolo/impl.rs | 56 +-- src/models/yolop/config.rs | 6 +- src/models/yolop/impl.rs | 15 +- src/utils/device.rs | 9 +- src/utils/dtype.rs | 3 +- 
src/utils/kind.rs | 18 - src/utils/mod.rs | 6 +- src/utils/ops.rs | 30 +- src/utils/options.rs | 488 ------------------------ src/utils/processor.rs | 49 ++- src/utils/processor_config.rs | 245 ++++++++++++ src/viz/annotator.rs | 2 +- 118 files changed, 1735 insertions(+), 2503 deletions(-) create mode 100644 src/inference/engine_config.rs create mode 100644 src/inference/model_config.rs delete mode 100644 src/utils/kind.rs delete mode 100644 src/utils/options.rs create mode 100644 src/utils/processor_config.rs diff --git a/Cargo.toml b/Cargo.toml index 1cd2b59..b4146ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "usls" edition = "2021" -version = "0.1.0-beta.1" +version = "0.1.0-beta.2" rust-version = "1.82" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." repository = "https://github.com/jamjamjon/usls" @@ -44,6 +44,7 @@ ort = { version = "2.0.0-rc.9", default-features = false, optional = true , feat "half" ]} tokenizers = { version = "0.21.1" } +paste = "1.0.15" [build-dependencies] prost-build = "0.13.5" diff --git a/README.md b/README.md index f03afd1..cb658fb 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,8 @@ | [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection
Open-Set Keypoints Detection
Image Caption
Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | | | [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | | | [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | ✅ | ✅ | ✅ | | | -| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation Answering | [demo](examples/rmbg) | ✅ | ✅ | ✅ | | | +| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation
Background Erase | [demo](examples/rmbg) | ✅ | ✅ | ✅ | | | +| [BEN2](https://huggingface.co/PramaLLC/BEN2) | Image Segmentation
Background Erase | [demo](examples/ben2) | ✅ | ✅ | ✅ | | | diff --git a/examples/ben2/main.rs b/examples/ben2/main.rs index 9096318..5d0375c 100644 --- a/examples/ben2/main.rs +++ b/examples/ben2/main.rs @@ -1,4 +1,4 @@ -use usls::{models::RMBG, Annotator, DataLoader, Options}; +use usls::{models::RMBG, Annotator, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example @@ -20,11 +20,11 @@ fn main() -> anyhow::Result<()> { let args: Args = argh::from_env(); // build model - let options = Options::ben2_base() + let config = ModelConfig::ben2_base() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) .commit()?; - let mut model = RMBG::new(options)?; + let mut model = RMBG::new(config)?; // load image let xs = DataLoader::try_read_n(&["./assets/cat.png"])?; diff --git a/examples/blip/main.rs b/examples/blip/main.rs index a04bfb8..6367a67 100644 --- a/examples/blip/main.rs +++ b/examples/blip/main.rs @@ -1,4 +1,4 @@ -use usls::{models::Blip, DataLoader, Options}; +use usls::{models::Blip, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// BLIP Example @@ -20,13 +20,10 @@ fn main() -> anyhow::Result<()> { let args: Args = argh::from_env(); // build model - let options_visual = Options::blip_v1_base_caption_visual() - .with_model_device(args.device.as_str().try_into()?) + let config = ModelConfig::blip_v1_base_caption() + .with_device_all(args.device.as_str().try_into()?) .commit()?; - let options_textual = Options::blip_v1_base_caption_textual() - .with_model_device(args.device.as_str().try_into()?) 
- .commit()?; - let mut model = Blip::new(options_visual, options_textual)?; + let mut model = Blip::new(config)?; // image caption let xs = DataLoader::try_read_n(&args.source)?; diff --git a/examples/classifier/main.rs b/examples/classifier/main.rs index fe9536e..73245ff 100644 --- a/examples/classifier/main.rs +++ b/examples/classifier/main.rs @@ -1,4 +1,4 @@ -use usls::{models::ImageClassifier, Annotator, DataLoader, Options}; +use usls::{models::ImageClassifier, Annotator, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example @@ -36,20 +36,20 @@ fn main() -> anyhow::Result<()> { let args: Args = argh::from_env(); // build model - let options = match args.model.to_lowercase().as_str() { - "beit" => Options::beit_base(), - "convnext" => Options::convnext_v2_atto(), - "deit" => Options::deit_tiny_distill(), - "fastvit" => Options::fastvit_t8_distill(), - "mobileone" => Options::mobileone_s0(), + let config = match args.model.to_lowercase().as_str() { + "beit" => ModelConfig::beit_base(), + "convnext" => ModelConfig::convnext_v2_atto(), + "deit" => ModelConfig::deit_tiny_distill(), + "fastvit" => ModelConfig::fastvit_t8_distill(), + "mobileone" => ModelConfig::mobileone_s0(), _ => anyhow::bail!("Unsupported model: {}", args.model), }; - let options = options + let config = config .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) 
.commit()?; - let mut model = ImageClassifier::try_from(options)?; + let mut model = ImageClassifier::try_from(config)?; // load images let xs = DataLoader::try_read_n(&args.source)?; diff --git a/examples/clip/main.rs b/examples/clip/main.rs index a600913..c650d00 100644 --- a/examples/clip/main.rs +++ b/examples/clip/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::Clip, DataLoader, Ops, Options}; +use usls::{models::Clip, DataLoader, ModelConfig, Ops}; #[derive(argh::FromArgs)] /// CLIP Example @@ -14,18 +14,13 @@ fn main() -> Result<()> { .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .init(); - let args: Args = argh::from_env(); + // build model - let options_visual = Options::jina_clip_v1_visual() - // clip_vit_b32_visual() - .with_model_device(args.device.as_str().try_into()?) + let config = ModelConfig::jina_clip_v1() + .with_device_all(args.device.as_str().try_into()?) .commit()?; - let options_textual = Options::jina_clip_v1_textual() - // clip_vit_b32_textual() - .with_model_device(args.device.as_str().try_into()?) 
- .commit()?; - let mut model = Clip::new(options_visual, options_textual)?; + let mut model = Clip::new(config)?; // texts let texts = vec![ diff --git a/examples/d-fine/main.rs b/examples/d-fine/main.rs index b74e3dc..ffdfaca 100644 --- a/examples/d-fine/main.rs +++ b/examples/d-fine/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RTDETR, Annotator, DataLoader, Options}; +use usls::{models::RTDETR, Annotator, DataLoader, ModelConfig}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -7,9 +7,8 @@ fn main() -> Result<()> { .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .init(); - // options - let options = Options::d_fine_n_coco().commit()?; - let mut model = RTDETR::new(options)?; + // config + let mut model = RTDETR::new(ModelConfig::d_fine_n_coco().commit()?)?; // load let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/db/main.rs b/examples/db/main.rs index e205c35..2f92e65 100644 --- a/examples/db/main.rs +++ b/examples/db/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::DB, Annotator, DataLoader, Options, Style}; +use usls::{models::DB, Annotator, DataLoader, ModelConfig, Style}; #[derive(argh::FromArgs)] /// Example @@ -41,15 +41,13 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let options = match &args.model { - Some(m) => Options::db().with_model_file(m), - None => Options::ppocr_det_v4_ch().with_model_dtype(args.dtype.as_str().try_into()?), - }; - let mut model = DB::new( - options - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - )?; + let config = match &args.model { + Some(m) => ModelConfig::db().with_model_file(m), + None => ModelConfig::ppocr_det_v4_ch().with_model_dtype(args.dtype.as_str().try_into()?), + } + .with_device_all(args.device.as_str().try_into()?) 
+ .commit()?; + let mut model = DB::new(config)?; // load image let xs = DataLoader::try_read_n(&[ diff --git a/examples/deim/main.rs b/examples/deim/main.rs index b253394..6fc358f 100644 --- a/examples/deim/main.rs +++ b/examples/deim/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RTDETR, Annotator, DataLoader, Options}; +use usls::{models::RTDETR, Annotator, DataLoader, ModelConfig}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -7,9 +7,8 @@ fn main() -> Result<()> { .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .init(); - // options - let options = Options::deim_dfine_s_coco().commit()?; - let mut model = RTDETR::new(options)?; + // config + let mut model = RTDETR::new(ModelConfig::deim_dfine_s_coco().commit()?)?; // load let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/depth-anything/main.rs b/examples/depth-anything/main.rs index f5b1d3e..3981ce2 100644 --- a/examples/depth-anything/main.rs +++ b/examples/depth-anything/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::DepthAnything, Annotator, DataLoader, Options, Style}; +use usls::{models::DepthAnything, Annotator, DataLoader, ModelConfig, Style}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,8 +8,7 @@ fn main() -> Result<()> { .init(); // build model - let options = Options::depth_anything_v2_small().commit()?; - let mut model = DepthAnything::new(options)?; + let mut model = DepthAnything::new(ModelConfig::depth_anything_v2_small().commit()?)?; // load let xs = DataLoader::try_read_n(&["images/street.jpg"])?; diff --git a/examples/depth-pro/main.rs b/examples/depth-pro/main.rs index 929fa53..8c0f559 100644 --- a/examples/depth-pro/main.rs +++ b/examples/depth-pro/main.rs @@ -1,6 +1,6 @@ use anyhow::Result; use usls::DataLoader; -use usls::{models::DepthPro, Annotator, Options, Style}; +use usls::{models::DepthPro, Annotator, ModelConfig, Style}; #[derive(argh::FromArgs)] /// Example @@ -23,11 +23,11 @@ 
fn main() -> Result<()> { let args: Args = argh::from_env(); // model - let options = Options::depth_pro() + let config = ModelConfig::depth_pro() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) .commit()?; - let mut model = DepthPro::new(options)?; + let mut model = DepthPro::new(config)?; // load let xs = DataLoader::try_read_n(&["images/street.jpg"])?; diff --git a/examples/dinov2/main.rs b/examples/dinov2/main.rs index e39427b..f9eda41 100644 --- a/examples/dinov2/main.rs +++ b/examples/dinov2/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::DINOv2, DataLoader, Options}; +use usls::{models::DINOv2, DataLoader, ModelConfig}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -11,8 +11,10 @@ fn main() -> Result<()> { let xs = DataLoader::try_read_n(&["./assets/bus.jpg", "./assets/bus.jpg"])?; // model - let options = Options::dinov2_small().with_batch_size(xs.len()).commit()?; - let mut model = DINOv2::new(options)?; + let config = ModelConfig::dinov2_small() + .with_batch_size_all(xs.len()) + .commit()?; + let mut model = DINOv2::new(config)?; // encode images let y = model.encode_images(&xs)?; diff --git a/examples/doclayout-yolo/main.rs b/examples/doclayout-yolo/main.rs index b6ab0d8..3b2c467 100644 --- a/examples/doclayout-yolo/main.rs +++ b/examples/doclayout-yolo/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLO, Annotator, DataLoader, Options}; +use usls::{models::YOLO, Annotator, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example @@ -18,7 +18,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = Options::doclayout_yolo_docstructbench() + let config = ModelConfig::doclayout_yolo_docstructbench() .with_model_device(args.device.as_str().try_into()?) 
.commit()?; let mut model = YOLO::new(config)?; diff --git a/examples/fast/main.rs b/examples/fast/main.rs index 9304c9d..493f737 100644 --- a/examples/fast/main.rs +++ b/examples/fast/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::DB, Annotator, DataLoader, Options, Scale, Style}; +use usls::{models::DB, Annotator, DataLoader, ModelConfig, Scale, Style}; #[derive(argh::FromArgs)] /// Example @@ -26,16 +26,16 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let options = match args.scale.as_str().try_into()? { - Scale::T => Options::fast_tiny(), - Scale::S => Options::fast_small(), - Scale::B => Options::fast_base(), + let config = match args.scale.as_str().try_into()? { + Scale::T => ModelConfig::fast_tiny(), + Scale::S => ModelConfig::fast_small(), + Scale::B => ModelConfig::fast_base(), _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale), }; let mut model = DB::new( - options - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) + config + .with_dtype_all(args.dtype.as_str().try_into()?) + .with_device_all(args.device.as_str().try_into()?) .commit()?, )?; diff --git a/examples/fastsam/main.rs b/examples/fastsam/main.rs index 29e2c1e..47a54a2 100644 --- a/examples/fastsam/main.rs +++ b/examples/fastsam/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLO, Annotator, DataLoader, Options}; +use usls::{models::YOLO, Annotator, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example @@ -22,7 +22,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = Options::fastsam_s() + let config = ModelConfig::fastsam_s() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) 
.commit()?; diff --git a/examples/florence2/README.md b/examples/florence2/README.md index 6078515..377cb5d 100644 --- a/examples/florence2/README.md +++ b/examples/florence2/README.md @@ -1,7 +1,7 @@ ## Quick Start ```shell -cargo run -r -F cuda --example florence2 -- --device cuda --scale base --dtype fp16 +cargo run -r -F cuda --example florence2 -- --device cuda --dtype fp16 ``` diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs index 4d6bde0..63bf33c 100644 --- a/examples/florence2/main.rs +++ b/examples/florence2/main.rs @@ -1,20 +1,16 @@ use anyhow::Result; -use usls::{models::Florence2, Annotator, DataLoader, Options, Scale, Style, Task}; +use usls::{models::Florence2, Annotator, DataLoader, ModelConfig, Style, Task}; #[derive(argh::FromArgs)] /// Example struct Args { /// dtype - #[argh(option, default = "String::from(\"auto\")")] + #[argh(option, default = "String::from(\"fp16\")")] dtype: String, /// device #[argh(option, default = "String::from(\"cpu:0\")")] device: String, - - /// scale - #[argh(option, default = "String::from(\"base\")")] - scale: String, } fn main() -> Result<()> { @@ -29,51 +25,12 @@ fn main() -> Result<()> { let xs = DataLoader::try_read_n(&["images/green-car.jpg", "assets/bus.jpg"])?; // build model - let ( - options_vision_encoder, - options_text_embed, - options_encoder, - options_decoder, - options_decoder_merged, - ) = match args.scale.as_str().try_into()? { - Scale::B => ( - Options::florence2_visual_encoder_base(), - Options::florence2_textual_embed_base(), - Options::florence2_texual_encoder_base(), - Options::florence2_texual_decoder_base(), - Options::florence2_texual_decoder_merged_base(), - ), - Scale::L => todo!(), - _ => anyhow::bail!("Unsupported Florence2 scale."), - }; - - let mut model = Florence2::new( - options_vision_encoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) 
- .with_batch_size(xs.len()) - .commit()?, - options_text_embed - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .with_batch_size(xs.len()) - .commit()?, - options_encoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .with_batch_size(xs.len()) - .commit()?, - options_decoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .with_batch_size(xs.len()) - .commit()?, - options_decoder_merged - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .with_batch_size(xs.len()) - .commit()?, - )?; + let config = ModelConfig::florence2_base() + .with_dtype_all(args.dtype.as_str().try_into()?) + .with_device_all(args.device.as_str().try_into()?) + .with_batch_size_all(xs.len()) + .commit()?; + let mut model = Florence2::new(config)?; // tasks let tasks = [ diff --git a/examples/grounding-dino/main.rs b/examples/grounding-dino/main.rs index e80083f..e155428 100644 --- a/examples/grounding-dino/main.rs +++ b/examples/grounding-dino/main.rs @@ -1,11 +1,11 @@ use anyhow::Result; -use usls::{models::GroundingDINO, Annotator, DataLoader, Options}; +use usls::{models::GroundingDINO, Annotator, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example struct Args { /// dtype - #[argh(option, default = "String::from(\"auto\")")] + #[argh(option, default = "String::from(\"fp16\")")] dtype: String, /// device @@ -45,7 +45,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); - let options = Options::grounding_dino_tiny() + let config = ModelConfig::grounding_dino_tiny() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) 
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>()) @@ -53,7 +53,7 @@ fn main() -> Result<()> { .with_text_confs(&[0.25]) .commit()?; - let mut model = GroundingDINO::new(options)?; + let mut model = GroundingDINO::new(config)?; // load images let xs = DataLoader::try_read_n(&args.source)?; diff --git a/examples/linknet/main.rs b/examples/linknet/main.rs index 3dfd21a..f7c85ad 100644 --- a/examples/linknet/main.rs +++ b/examples/linknet/main.rs @@ -1,6 +1,6 @@ use anyhow::Result; use usls::DataLoader; -use usls::{models::DB, Annotator, Options, Scale, Style}; +use usls::{models::DB, Annotator, ModelConfig, Scale, Style}; #[derive(argh::FromArgs)] /// Example @@ -27,14 +27,14 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let options = match args.scale.as_str().try_into()? { - Scale::T => Options::linknet_r18(), - Scale::S => Options::linknet_r34(), - Scale::B => Options::linknet_r50(), + let config = match args.scale.as_str().try_into()? { + Scale::T => ModelConfig::linknet_r18(), + Scale::S => ModelConfig::linknet_r34(), + Scale::B => ModelConfig::linknet_r50(), _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale), }; let mut model = DB::new( - options + config .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) 
.commit()?, diff --git a/examples/modnet/main.rs b/examples/modnet/main.rs index b0ee231..9dde076 100644 --- a/examples/modnet/main.rs +++ b/examples/modnet/main.rs @@ -1,4 +1,4 @@ -use usls::{models::MODNet, Annotator, DataLoader, Options}; +use usls::{models::MODNet, Annotator, DataLoader, ModelConfig}; fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() @@ -7,8 +7,7 @@ fn main() -> anyhow::Result<()> { .init(); // build model - let options = Options::modnet_photographic().commit()?; - let mut model = MODNet::new(options)?; + let mut model = MODNet::new(ModelConfig::modnet_photographic().commit()?)?; // load image let xs = DataLoader::try_read_n(&["images/liuyifei.png"])?; diff --git a/examples/moondream2/main.rs b/examples/moondream2/main.rs index 963197a..907a9aa 100644 --- a/examples/moondream2/main.rs +++ b/examples/moondream2/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::Moondream2, Annotator, DataLoader, Options, Scale, Task}; +use usls::{models::Moondream2, Annotator, DataLoader, ModelConfig, Scale, Task}; #[derive(argh::FromArgs)] /// Example @@ -39,81 +39,16 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let ( - options_vision_encoder, - options_vision_projection, - options_text_decoder, - options_text_encoder, - options_coord_decoder, - options_coord_encoder, - options_size_decoder, - options_size_encoder, - ) = match args.scale.as_str().try_into()? { - Scale::Billion(2.) 
=> ( - Options::moondream2_2b_vision_encoder(), - Options::moondream2_2b_vision_projection(), - Options::moondream2_2b_text_decoder(), - Options::moondream2_2b_text_encoder(), - Options::moondream2_2b_coord_decoder(), - Options::moondream2_2b_coord_encoder(), - Options::moondream2_2b_size_decoder(), - Options::moondream2_2b_size_encoder(), - ), - Scale::Billion(0.5) => ( - Options::moondream2_0_5b_vision_encoder(), - Options::moondream2_0_5b_vision_projection(), - Options::moondream2_0_5b_text_decoder(), - Options::moondream2_0_5b_text_encoder(), - Options::moondream2_0_5b_coord_decoder(), - Options::moondream2_0_5b_coord_encoder(), - Options::moondream2_0_5b_size_decoder(), - Options::moondream2_0_5b_size_encoder(), - ), + let config = match args.scale.as_str().try_into()? { + Scale::Billion(0.5) => ModelConfig::moondream2_0_5b(), + Scale::Billion(2.) => ModelConfig::moondream2_2b(), _ => unimplemented!(), - }; + } + .with_dtype_all(args.dtype.as_str().try_into()?) + .with_device_all(args.device.as_str().try_into()?) + .commit()?; - let mut model = Moondream2::new( - options_vision_encoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - options_vision_projection - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - options_text_encoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - options_text_decoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - Some( - options_coord_encoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - ), - Some( - options_coord_decoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) 
- .commit()?, - ), - Some( - options_size_encoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - ), - Some( - options_size_decoder - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - ), - )?; + let mut model = Moondream2::new(config)?; // load images let xs = DataLoader::try_read_n(&args.source)?; @@ -142,13 +77,6 @@ fn main() -> Result<()> { } Task::OpenSetDetection(_) | Task::OpenSetKeypointsDetection(_) => { println!("{:?}", ys); - // let annotator = Annotator::default() - // .with_bboxes_thickness(4) - // .without_bboxes_conf(true) - // .with_keypoints_radius(6) - // .with_keypoints_name(true) - // .with_saveout("moondream2"); - // annotator.annotate(&xs, &ys); // annotate let annotator = Annotator::default() diff --git a/examples/owlv2/main.rs b/examples/owlv2/main.rs index f717e00..a0f9a08 100644 --- a/examples/owlv2/main.rs +++ b/examples/owlv2/main.rs @@ -1,6 +1,6 @@ use anyhow::Result; use usls::DataLoader; -use usls::{models::OWLv2, Annotator, Options}; +use usls::{models::OWLv2, Annotator, ModelConfig}; #[derive(argh::FromArgs)] /// Example @@ -46,14 +46,14 @@ fn main() -> Result<()> { .init(); let args: Args = argh::from_env(); - // options - let options = Options::owlv2_base_ensemble() + // config + let config = ModelConfig::owlv2_base_ensemble() // owlv2_base() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) 
.with_class_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>()) .commit()?; - let mut model = OWLv2::new(options)?; + let mut model = OWLv2::new(config)?; // load let xs = DataLoader::try_read_n(&args.source)?; diff --git a/examples/picodet-layout/main.rs b/examples/picodet-layout/main.rs index baa9bf4..5034a54 100644 --- a/examples/picodet-layout/main.rs +++ b/examples/picodet-layout/main.rs @@ -1,6 +1,6 @@ use anyhow::Result; use usls::DataLoader; -use usls::{models::PicoDet, Annotator, Options}; +use usls::{models::PicoDet, Annotator, ModelConfig}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,12 +8,11 @@ fn main() -> Result<()> { .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .init(); - // options - let options = Options::picodet_layout_1x() - // picodet_l_layout_3cls() - // picodet_l_layout_17cls() - .commit()?; - let mut model = PicoDet::new(options)?; + // config + let config = ModelConfig::picodet_layout_1x().commit()?; + // picodet_l_layout_3cls() + // picodet_l_layout_17cls() + let mut model = PicoDet::new(config)?; // load let xs = DataLoader::try_read_n(&["images/academic.jpg"])?; diff --git a/examples/rfdetr/main.rs b/examples/rfdetr/main.rs index e034e45..eb39ef8 100644 --- a/examples/rfdetr/main.rs +++ b/examples/rfdetr/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RFDETR, Annotator, DataLoader, Options}; +use usls::{models::RFDETR, Annotator, DataLoader, ModelConfig}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -7,9 +7,8 @@ fn main() -> Result<()> { .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .init(); - // options - let options = Options::rfdetr_base().commit()?; - let mut model = RFDETR::new(options)?; + // config + let mut model = RFDETR::new(ModelConfig::rfdetr_base().commit()?)?; // load let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/rmbg/main.rs b/examples/rmbg/main.rs index cb34be0..4fbdfbd 100644 --- 
a/examples/rmbg/main.rs +++ b/examples/rmbg/main.rs @@ -1,10 +1,10 @@ -use usls::{models::RMBG, Annotator, DataLoader, Options}; +use usls::{models::RMBG, Annotator, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example struct Args { /// dtype - #[argh(option, default = "String::from(\"auto\")")] + #[argh(option, default = "String::from(\"fp16\")")] dtype: String, /// device @@ -23,18 +23,18 @@ fn main() -> anyhow::Result<()> { .init(); let args: Args = argh::from_env(); - let options = match args.ver { - 1.4 => Options::rmbg1_4(), - 2.0 => Options::rmbg2_0(), + let config = match args.ver { + 1.4 => ModelConfig::rmbg1_4(), + 2.0 => ModelConfig::rmbg2_0(), _ => unreachable!("Unsupported version"), }; // build model - let options = options + let config = config .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) .commit()?; - let mut model = RMBG::new(options)?; + let mut model = RMBG::new(config)?; // load image let xs = DataLoader::try_read_n(&["./assets/cat.png"])?; diff --git a/examples/rtdetr/main.rs b/examples/rtdetr/main.rs index 1d3aa5e..8a19510 100644 --- a/examples/rtdetr/main.rs +++ b/examples/rtdetr/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RTDETR, Annotator, DataLoader, Options}; +use usls::{models::RTDETR, Annotator, DataLoader, ModelConfig}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -7,15 +7,14 @@ fn main() -> Result<()> { .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .init(); - // options - let options = Options::rtdetr_v2_s_coco() - // rtdetr_v1_r18vd_coco() - // rtdetr_v2_ms_coco() - // rtdetr_v2_m_coco() - // rtdetr_v2_l_coco() - // rtdetr_v2_x_coco() - .commit()?; - let mut model = RTDETR::new(options)?; + // config + let config = ModelConfig::rtdetr_v2_s_coco().commit()?; + // rtdetr_v1_r18vd_coco() + // rtdetr_v2_ms_coco() + // rtdetr_v2_m_coco() + // rtdetr_v2_l_coco() + // rtdetr_v2_x_coco() + let mut model = 
RTDETR::new(config)?; // load let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs index 314e6fc..c2dde49 100644 --- a/examples/rtmo/main.rs +++ b/examples/rtmo/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::RTMO, Annotator, DataLoader, Options, Style, SKELETON_COCO_19}; +use usls::{models::RTMO, Annotator, DataLoader, ModelConfig, Style, SKELETON_COCO_19}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,7 +8,7 @@ fn main() -> Result<()> { .init(); // build model - let mut model = RTMO::new(Options::rtmo_s().commit()?)?; + let mut model = RTMO::new(ModelConfig::rtmo_s().commit()?)?; // load image let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/sam/main.rs b/examples/sam/main.rs index 21e4fcc..0da9d7f 100644 --- a/examples/sam/main.rs +++ b/examples/sam/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use usls::{ models::{SamKind, SamPrompt, SAM}, - Annotator, DataLoader, Options, Scale, + Annotator, DataLoader, ModelConfig, Scale, }; #[derive(argh::FromArgs)] @@ -28,40 +28,22 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // Build model - let (options_encoder, options_decoder) = match args.kind.as_str().try_into()? { - SamKind::Sam => ( - Options::sam_v1_base_encoder(), - Options::sam_v1_base_decoder(), - ), + let config = match args.kind.as_str().try_into()? { + SamKind::Sam => ModelConfig::sam_v1_base(), SamKind::Sam2 => match args.scale.as_str().try_into()? { - Scale::T => (Options::sam2_tiny_encoder(), Options::sam2_tiny_decoder()), - Scale::S => (Options::sam2_small_encoder(), Options::sam2_small_decoder()), - Scale::B => ( - Options::sam2_base_plus_encoder(), - Options::sam2_base_plus_decoder(), - ), + Scale::T => ModelConfig::sam2_tiny(), + Scale::S => ModelConfig::sam2_small(), + Scale::B => ModelConfig::sam2_base_plus(), _ => unimplemented!("Unsupported model scale: {:?}. 
Try b, s, t.", args.scale), }, + SamKind::MobileSam => ModelConfig::mobile_sam_tiny(), + SamKind::SamHq => ModelConfig::sam_hq_tiny(), + SamKind::EdgeSam => ModelConfig::edge_sam_3x(), + } + .with_device_all(args.device.as_str().try_into()?) + .commit()?; - SamKind::MobileSam => ( - Options::mobile_sam_tiny_encoder(), - Options::mobile_sam_tiny_decoder(), - ), - SamKind::SamHq => ( - Options::sam_hq_tiny_encoder(), - Options::sam_hq_tiny_decoder(), - ), - SamKind::EdgeSam => ( - Options::edge_sam_3x_encoder(), - Options::edge_sam_3x_decoder(), - ), - }; - - let options_encoder = options_encoder - .with_model_device(args.device.as_str().try_into()?) - .commit()?; - let options_decoder = options_decoder.commit()?; - let mut model = SAM::new(options_encoder, options_decoder)?; + let mut model = SAM::new(config)?; // Load image let xs = DataLoader::try_read_n(&["images/truck.jpg"])?; diff --git a/examples/sam2/main.rs b/examples/sam2/main.rs index e8722a4..79469f9 100644 --- a/examples/sam2/main.rs +++ b/examples/sam2/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use usls::{ models::{SamPrompt, SAM2}, - Annotator, DataLoader, Options, Scale, + Annotator, DataLoader, ModelConfig, Scale, }; #[derive(argh::FromArgs)] @@ -25,33 +25,16 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // Build model - let (options_encoder, options_decoder) = match args.scale.as_str().try_into()? { - Scale::T => ( - Options::sam2_1_tiny_encoder(), - Options::sam2_1_tiny_decoder(), - ), - Scale::S => ( - Options::sam2_1_small_encoder(), - Options::sam2_1_small_decoder(), - ), - Scale::B => ( - Options::sam2_1_base_plus_encoder(), - Options::sam2_1_base_plus_decoder(), - ), - Scale::L => ( - Options::sam2_1_large_encoder(), - Options::sam2_1_large_decoder(), - ), + let config = match args.scale.as_str().try_into()? 
{ + Scale::T => ModelConfig::sam2_1_tiny(), + Scale::S => ModelConfig::sam2_1_small(), + Scale::B => ModelConfig::sam2_1_base_plus(), + Scale::L => ModelConfig::sam2_1_large(), _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale), - }; - - let options_encoder = options_encoder - .with_model_device(args.device.as_str().try_into()?) - .commit()?; - let options_decoder = options_decoder - .with_model_device(args.device.as_str().try_into()?) - .commit()?; - let mut model = SAM2::new(options_encoder, options_decoder)?; + } + .with_device_all(args.device.as_str().try_into()?) + .commit()?; + let mut model = SAM2::new(config)?; // Load image let xs = DataLoader::try_read_n(&["images/truck.jpg"])?; diff --git a/examples/sapiens/main.rs b/examples/sapiens/main.rs index a209e27..dbc2c23 100644 --- a/examples/sapiens/main.rs +++ b/examples/sapiens/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::Sapiens, Annotator, DataLoader, Options}; +use usls::{models::Sapiens, Annotator, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example @@ -17,10 +17,10 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build - let options = Options::sapiens_seg_0_3b() + let config = ModelConfig::sapiens_seg_0_3b() .with_model_device(args.device.as_str().try_into()?) 
.commit()?; - let mut model = Sapiens::new(options)?; + let mut model = Sapiens::new(config)?; // load let xs = DataLoader::try_read_n(&["images/paul-george.jpg"])?; diff --git a/examples/slanet/main.rs b/examples/slanet/main.rs index 362671c..fe66322 100644 --- a/examples/slanet/main.rs +++ b/examples/slanet/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::SLANet, Annotator, Color, DataLoader, Options}; +use usls::{models::SLANet, Annotator, Color, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example @@ -26,11 +26,11 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let options = Options::slanet_lcnet_v2_mobile_ch() + let config = ModelConfig::slanet_lcnet_v2_mobile_ch() .with_model_device(args.device.as_str().try_into()?) .with_model_dtype(args.dtype.as_str().try_into()?) .commit()?; - let mut model = SLANet::new(options)?; + let mut model = SLANet::new(config)?; // load let xs = DataLoader::try_read_n(&[args.source])?; diff --git a/examples/smolvlm/main.rs b/examples/smolvlm/main.rs index 3bc87fb..7069282 100644 --- a/examples/smolvlm/main.rs +++ b/examples/smolvlm/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::SmolVLM, DataLoader, Options, Scale}; +use usls::{models::SmolVLM, DataLoader, ModelConfig, Scale}; #[derive(argh::FromArgs)] /// Example @@ -29,32 +29,15 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let (options_vision_encoder, options_text_embed, options_decode) = - match args.scale.as_str().try_into()? { - Scale::Million(256.) => ( - Options::smolvlm_vision_256m(), - Options::smolvlm_text_embed_256m(), - Options::smolvlm_decoder_256m(), - ), - Scale::Million(500.) => ( - Options::smolvlm_vision_500m(), - Options::smolvlm_text_embed_500m(), - Options::smolvlm_decoder_500m(), - ), - _ => unimplemented!(), - }; + let config = match args.scale.as_str().try_into()? { + Scale::Million(256.) => ModelConfig::smolvlm_256m(), + Scale::Million(500.) 
=> ModelConfig::smolvlm_500m(), + _ => unimplemented!(), + } + .with_device_all(args.device.as_str().try_into()?) + .commit()?; - let mut model = SmolVLM::new( - options_vision_encoder - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - options_text_embed - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - options_decode - .with_model_device(args.device.as_str().try_into()?) - .commit()?, - )?; + let mut model = SmolVLM::new(config)?; // load images let xs = DataLoader::try_read_n(&args.source)?; diff --git a/examples/svtr/main.rs b/examples/svtr/main.rs index bd92036..f39c464 100644 --- a/examples/svtr/main.rs +++ b/examples/svtr/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::SVTR, DataLoader, Options}; +use usls::{models::SVTR, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example @@ -22,13 +22,13 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let options = Options::ppocr_rec_v4_ch() + let config = ModelConfig::ppocr_rec_v4_ch() // ppocr_rec_v4_en() // repsvtr_ch() .with_model_device(args.device.as_str().try_into()?) .with_model_dtype(args.dtype.as_str().try_into()?) .commit()?; - let mut model = SVTR::new(options)?; + let mut model = SVTR::new(config)?; // load images let dl = DataLoader::new("./examples/svtr/images")? diff --git a/examples/trocr/main.rs b/examples/trocr/main.rs index f64b72e..79aad83 100644 --- a/examples/trocr/main.rs +++ b/examples/trocr/main.rs @@ -1,6 +1,6 @@ use usls::{ models::{TrOCR, TrOCRKind}, - DataLoader, Options, Scale, + DataLoader, ModelConfig, Scale, }; #[derive(argh::FromArgs)] @@ -38,52 +38,22 @@ fn main() -> anyhow::Result<()> { ])?; // build model - let (options_encoder, options_decoder, options_decoder_merged) = - match args.scale.as_str().try_into()? { - Scale::S => match args.kind.as_str().try_into()? 
{ - TrOCRKind::Printed => ( - Options::trocr_encoder_small_printed(), - Options::trocr_decoder_small_printed(), - Options::trocr_decoder_merged_small_printed(), - ), - TrOCRKind::HandWritten => ( - Options::trocr_encoder_small_handwritten(), - Options::trocr_decoder_small_handwritten(), - Options::trocr_decoder_merged_small_handwritten(), - ), - }, - Scale::B => match args.kind.as_str().try_into()? { - TrOCRKind::Printed => ( - Options::trocr_encoder_base_printed(), - Options::trocr_decoder_base_printed(), - Options::trocr_decoder_merged_base_printed(), - ), - TrOCRKind::HandWritten => ( - Options::trocr_encoder_base_handwritten(), - Options::trocr_decoder_base_handwritten(), - Options::trocr_decoder_merged_base_handwritten(), - ), - }, - x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x), - }; + let config = match args.scale.as_str().try_into()? { + Scale::S => match args.kind.as_str().try_into()? { + TrOCRKind::Printed => ModelConfig::trocr_small_printed(), + TrOCRKind::HandWritten => ModelConfig::trocr_small_handwritten(), + }, + Scale::B => match args.kind.as_str().try_into()? { + TrOCRKind::Printed => ModelConfig::trocr_base_printed(), + TrOCRKind::HandWritten => ModelConfig::trocr_base_handwritten(), + }, + x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x), + } + .with_device_all(args.device.as_str().try_into()?) + .with_dtype_all(args.dtype.as_str().try_into()?) + .commit()?; - let mut model = TrOCR::new( - options_encoder - .with_model_device(args.device.as_str().try_into()?) - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_batch_size(xs.len()) - .commit()?, - options_decoder - .with_model_device(args.device.as_str().try_into()?) - .with_model_dtype(args.dtype.as_str().try_into()?) - .with_batch_size(xs.len()) - .commit()?, - options_decoder_merged - .with_model_device(args.device.as_str().try_into()?) - .with_model_dtype(args.dtype.as_str().try_into()?) 
- .with_batch_size(xs.len()) - .commit()?, - )?; + let mut model = TrOCR::new(config)?; // inference let ys = model.forward(&xs)?; diff --git a/examples/yolo-sam2/main.rs b/examples/yolo-sam2/main.rs index 8b60d87..ed3ca1c 100644 --- a/examples/yolo-sam2/main.rs +++ b/examples/yolo-sam2/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use usls::{ models::{SamPrompt, SAM2, YOLO}, - Annotator, DataLoader, Options, Scale, Style, + Annotator, DataLoader, ModelConfig, Scale, Style, }; #[derive(argh::FromArgs)] @@ -21,17 +21,14 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build SAM - let (options_encoder, options_decoder) = ( - Options::sam2_1_tiny_encoder().commit()?, - Options::sam2_1_tiny_decoder().commit()?, - ); - let mut sam = SAM2::new(options_encoder, options_decoder)?; + let mut sam = SAM2::new(ModelConfig::sam2_1_tiny().commit()?)?; // build YOLOv8 - let options_yolo = Options::yolo_detect() - .with_model_scale(Scale::N) - .with_model_version(8.into()) + let options_yolo = ModelConfig::yolo_detect() + .with_scale(Scale::N) + .with_version(8.into()) .with_model_device(args.device.as_str().try_into()?) 
+ .auto_yolo_model_file() .commit()?; let mut yolo = YOLO::new(options_yolo)?; diff --git a/examples/yolo/README.md b/examples/yolo/README.md index d53104e..d3d214f 100644 --- a/examples/yolo/README.md +++ b/examples/yolo/README.md @@ -54,7 +54,7 @@ cargo run -r --example yolo -- --ver 8 --task obb --scale n --image-width 1024 - cargo run -r --example yolo -- --ver 11 --task obb --scale n --image-width 1024 --image-height 1024 --source images/dota.png # YOLOv11-Obb ``` -**`cargo run -r --example yolo -- --help` for more options** +**`cargo run -r --example yolo -- --help` for more config** ## Other YOLOv8 Solution Models diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index 86111bd..4f69ad9 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use usls::{ - models::YOLO, Annotator, DataLoader, Options, Style, NAMES_COCO_80, NAMES_COCO_KEYPOINTS_17, - NAMES_IMAGENET_1K, SKELETON_COCO_19, SKELETON_COLOR_COCO_19, + models::YOLO, Annotator, DataLoader, ModelConfig, Style, NAMES_COCO_80, + NAMES_COCO_KEYPOINTS_17, NAMES_IMAGENET_1K, SKELETON_COCO_19, SKELETON_COLOR_COCO_19, }; #[derive(argh::FromArgs, Debug)] @@ -132,14 +132,15 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); - let mut options = Options::yolo() + let mut config = ModelConfig::yolo() .with_model_file(&args.model.unwrap_or_default()) - .with_model_task(args.task.as_str().try_into()?) - .with_model_version(args.ver.try_into()?) - .with_model_scale(args.scale.as_str().try_into()?) + .with_task(args.task.as_str().try_into()?) + .with_version(args.ver.try_into()?) + .with_scale(args.scale.as_str().try_into()?) .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) 
- .with_trt_fp16(args.trt_fp16) + // .with_trt_fp16(args.trt_fp16) + .with_model_trt_fp16(args.trt_fp16) .with_model_ixx( 0, 0, @@ -175,27 +176,27 @@ fn main() -> Result<()> { .exclude_classes(&args.exclude_classes); if args.use_coco_80_classes { - options = options.with_class_names(&NAMES_COCO_80); + config = config.with_class_names(&NAMES_COCO_80); } if args.use_coco_17_keypoints_classes { - options = options.with_keypoint_names(&NAMES_COCO_KEYPOINTS_17); + config = config.with_keypoint_names(&NAMES_COCO_KEYPOINTS_17); } if args.use_imagenet_1k_classes { - options = options.with_class_names(&NAMES_IMAGENET_1K); + config = config.with_class_names(&NAMES_IMAGENET_1K); } if let Some(nc) = args.num_classes { - options = options.with_nc(nc); + config = config.with_nc(nc); } if let Some(nk) = args.num_keypoints { - options = options.with_nk(nk); + config = config.with_nk(nk); } if !args.class_names.is_empty() { - options = options.with_class_names( + config = config.with_class_names( &args .class_names .iter() @@ -205,7 +206,7 @@ fn main() -> Result<()> { } if !args.keypoint_names.is_empty() { - options = options.with_keypoint_names( + config = config.with_keypoint_names( &args .keypoint_names .iter() @@ -215,7 +216,7 @@ fn main() -> Result<()> { } // build model - let mut model = YOLO::try_from(options.commit()?)?; + let mut model = YOLO::try_from(config.auto_yolo_model_file().commit()?)?; // build dataloader let dl = DataLoader::new(&args.source)? 
diff --git a/examples/yoloe/README.md b/examples/yoloe/README.md index 110c5a4..005ff33 100644 --- a/examples/yoloe/README.md +++ b/examples/yoloe/README.md @@ -1,6 +1,6 @@ ## Quick Start ```shell -cargo run -r --example yoloe +cargo run -r -F cuda --example yoloe -- --device cuda ``` diff --git a/examples/yoloe/main.rs b/examples/yoloe/main.rs index 841945c..175cbb2 100644 --- a/examples/yoloe/main.rs +++ b/examples/yoloe/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLO, Annotator, DataLoader, Options, Style}; +use usls::{models::YOLO, Annotator, DataLoader, ModelConfig, Style}; #[derive(argh::FromArgs)] /// Example @@ -21,8 +21,8 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); - // options - let options = Options::yoloe_v8s_seg_pf() + // config + let config = ModelConfig::yoloe_v8s_seg_pf() // yoloe_v8m_seg_pf() // yoloe_v8l_seg_pf() // yoloe_11s_seg_pf() @@ -31,7 +31,7 @@ fn main() -> Result<()> { .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) 
.commit()?; - let mut model = YOLO::new(options)?; + let mut model = YOLO::new(config)?; // load let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; diff --git a/examples/yolop/main.rs b/examples/yolop/main.rs index 7f062db..767926b 100644 --- a/examples/yolop/main.rs +++ b/examples/yolop/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLOPv2, Annotator, DataLoader, Options}; +use usls::{models::YOLOPv2, Annotator, DataLoader, ModelConfig}; fn main() -> Result<()> { tracing_subscriber::fmt() @@ -8,8 +8,7 @@ fn main() -> Result<()> { .init(); // build model - let options = Options::yolop_v2_480x800().commit()?; - let mut model = YOLOPv2::new(options)?; + let mut model = YOLOPv2::new(ModelConfig::yolop_v2_480x800().commit()?)?; // load image let xs = DataLoader::try_read_n(&["images/car-view.jpg"])?; diff --git a/examples/yolov8-rtdetr/main.rs b/examples/yolov8-rtdetr/main.rs index 3b96e8e..b085f28 100644 --- a/examples/yolov8-rtdetr/main.rs +++ b/examples/yolov8-rtdetr/main.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use usls::{models::YOLO, Annotator, DataLoader, Options}; +use usls::{models::YOLO, Annotator, DataLoader, ModelConfig}; #[derive(argh::FromArgs)] /// Example @@ -22,7 +22,7 @@ fn main() -> Result<()> { let args: Args = argh::from_env(); // build model - let config = Options::yolo_v8_rtdetr_l() + let config = ModelConfig::yolo_v8_rtdetr_l() .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) 
.commit()?; diff --git a/src/inference/engine.rs b/src/inference/engine.rs index ad6cdf0..cd2a9f3 100644 --- a/src/inference/engine.rs +++ b/src/inference/engine.rs @@ -13,8 +13,8 @@ use prost::Message; use std::collections::HashSet; use crate::{ - build_progress_bar, elapsed, human_bytes_binary, onnx, DType, Device, Iiix, MinOptMax, Ops, Ts, - Xs, PROGRESS_BAR_STYLE_CYAN_2, PROGRESS_BAR_STYLE_FINISH, X, + build_progress_bar, elapsed, human_bytes_binary, onnx, DType, Device, EngineConfig, Iiix, + MinOptMax, Ops, Ts, Xs, PROGRESS_BAR_STYLE_CYAN_2, PROGRESS_BAR_STYLE_FINISH, X, }; impl From for DType { @@ -93,6 +93,20 @@ impl Default for Engine { } impl Engine { + pub fn try_from_config(config: &EngineConfig) -> Result { + Self { + file: config.file.clone(), + spec: config.spec.clone(), + iiixs: config.iiixs.clone(), + device: config.device, + trt_fp16: config.trt_fp16, + num_dry_run: config.num_dry_run, + graph_opt_level: config.ort_graph_opt_level, + ..Default::default() + } + .build() + } + pub fn build(mut self) -> Result { let name = format!("[{}] ort_initialization", self.spec); elapsed!(&name, self.ts, { diff --git a/src/inference/engine_config.rs b/src/inference/engine_config.rs new file mode 100644 index 0000000..c63f251 --- /dev/null +++ b/src/inference/engine_config.rs @@ -0,0 +1,112 @@ +use aksr::Builder; +use anyhow::Result; + +use crate::{try_fetch_file_stem, DType, Device, Hub, Iiix, MinOptMax}; + +#[derive(Builder, Debug, Clone, Default)] +pub struct EngineConfig { + pub file: String, + pub device: Device, + pub iiixs: Vec, + pub num_dry_run: usize, + pub trt_fp16: bool, + pub ort_graph_opt_level: Option, + pub spec: String, // TODO: move out + pub dtype: DType, // For dynamically loading the model +} + +impl EngineConfig { + pub fn try_commit(mut self, name: &str) -> Result { + // Identify the local model or fetch the remote model + if std::path::PathBuf::from(&self.file).exists() { + // Local + self.spec = format!("{}/{}", name, 
try_fetch_file_stem(&self.file)?); + } else { + if self.file.is_empty() && name.is_empty() { + anyhow::bail!( + "Failed to commit model. Invalid model config: neither `name` nor `file` were specified. Failed to fetch model from Hub." + ) + } + + // Remote + match Hub::is_valid_github_release_url(&self.file) { + Some((owner, repo, tag, _file_name)) => { + let stem = try_fetch_file_stem(&self.file)?; + self.spec = format!("{}/{}-{}-{}-{}", name, owner, repo, tag, stem); + self.file = Hub::default().try_fetch(&self.file)?; + } + None => { + // append dtype to model file + match self.dtype { + d @ (DType::Auto | DType::Fp32) => { + if self.file.is_empty() { + self.file = format!("{}.onnx", d); + } + } + dtype => { + if self.file.is_empty() { + self.file = format!("{}.onnx", dtype); + } else { + let pos = self.file.len() - 5; // .onnx + let suffix = self.file.split_off(pos); + self.file = format!("{}-{}{}", self.file, dtype, suffix); + } + } + } + + let stem = try_fetch_file_stem(&self.file)?; + self.spec = format!("{}/{}", name, stem); + self.file = Hub::default().try_fetch(&format!("{}/{}", name, self.file))?; + } + } + } + + Ok(self) + } +} + +impl EngineConfig { + pub fn with_ixx(mut self, i: usize, ii: usize, x: MinOptMax) -> Self { + self.iiixs.push(Iiix::from((i, ii, x))); + self + } + + pub fn with_batch_size(mut self, x: MinOptMax) -> Self { + self.iiixs.push(Iiix::from((0, 0, x))); + self + } +} + +#[macro_export] +macro_rules! impl_model_config_methods { + ($ty:ty, $field:ident) => { + impl $ty { + paste::paste! 
{ + pub fn [](mut self, file: &str) -> Self { + self.$field = self.$field.with_file(file); + self + } + pub fn [](mut self, dtype: $crate::DType) -> Self { + self.$field = self.$field.with_dtype(dtype); + self + } + pub fn [](mut self, device: $crate::Device) -> Self { + self.$field = self.$field.with_device(device); + self + } + pub fn [](mut self, x: bool) -> Self { + self.$field = self.$field.with_trt_fp16(x); + self + } + pub fn [](mut self, x: usize) -> Self { + self.$field = self.$field.with_num_dry_run(x); + self + } + pub fn [](mut self, i: usize, ii: usize, x: $crate::MinOptMax) -> Self { + self.$field = self.$field.with_ixx(i, ii, x); + self + } + } + } + }; +} diff --git a/src/inference/image.rs b/src/inference/image.rs index 50c4778..23b113f 100644 --- a/src/inference/image.rs +++ b/src/inference/image.rs @@ -308,12 +308,12 @@ impl Image { )); } - let (mut resizer, options) = build_resizer_filter(filter)?; + let (mut resizer, config) = build_resizer_filter(filter)?; let x: DynamicImage = self.to_dyn(); if let ResizeMode::FitExact = mode { let mut dst = FImage::new(tw, th, PixelType::U8x3); - resizer.resize(&x, &mut dst, &options)?; + resizer.resize(&x, &mut dst, &config)?; trans_info = trans_info .with_height_scale(th as f32 / h0 as f32) .with_width_scale(tw as f32 / w0 as f32); @@ -362,7 +362,7 @@ impl Image { }; let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; - resizer.resize(&x, &mut dst_cropped, &options)?; + resizer.resize(&x, &mut dst_cropped, &config)?; Ok((Self::from_u8s(&dst.into_vec(), tw, th)?, trans_info)) } diff --git a/src/inference/mod.rs b/src/inference/mod.rs index 1cb35bd..8e5783c 100644 --- a/src/inference/mod.rs +++ b/src/inference/mod.rs @@ -1,10 +1,12 @@ #[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] mod engine; +mod engine_config; mod hbb; mod image; mod instance_meta; mod keypoint; mod mask; +mod model_config; mod obb; mod polygon; mod prob; @@ -20,11 +22,13 @@ pub(crate) mod onnx 
{ #[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))] pub use engine::*; +pub use engine_config::EngineConfig; pub use hbb::*; pub use image::*; pub use instance_meta::*; pub use keypoint::*; pub use mask::*; +pub use model_config::*; pub use obb::*; pub use polygon::*; pub use prob::*; diff --git a/src/inference/model_config.rs b/src/inference/model_config.rs new file mode 100644 index 0000000..6870cfe --- /dev/null +++ b/src/inference/model_config.rs @@ -0,0 +1,243 @@ +use aksr::Builder; + +use crate::{ + impl_model_config_methods, impl_process_config_methods, + models::{SamKind, YOLOPredsFormat}, + EngineConfig, ProcessorConfig, Scale, Task, Version, +}; + +/// ModelConfig for building models and inference +#[derive(Builder, Debug, Clone)] +pub struct ModelConfig { + // Basics + pub name: &'static str, + pub version: Option, + pub task: Option, + pub scale: Option, + + // Engines + pub model: EngineConfig, + pub visual: EngineConfig, + pub textual: EngineConfig, + pub encoder: EngineConfig, + pub decoder: EngineConfig, + pub visual_encoder: EngineConfig, + pub textual_encoder: EngineConfig, + pub visual_decoder: EngineConfig, + pub textual_decoder: EngineConfig, + pub textual_decoder_merged: EngineConfig, + pub size_encoder: EngineConfig, + pub size_decoder: EngineConfig, + pub coord_encoder: EngineConfig, + pub coord_decoder: EngineConfig, + pub visual_projection: EngineConfig, + pub textual_projection: EngineConfig, + + // Processor + pub processor: ProcessorConfig, + + // Others + pub class_names: Option>, // TODO: remove Option + pub keypoint_names: Option>, // TODO: remove Option + pub text_names: Option>, // TODO: remove Option + pub class_confs: Vec, + pub keypoint_confs: Vec, + pub text_confs: Vec, + pub apply_softmax: Option, + pub topk: Option, + #[args(aka = "nc")] + pub num_classes: Option, + #[args(aka = "nk")] + pub num_keypoints: Option, + #[args(aka = "nm")] + pub num_masks: Option, + pub iou: Option, + pub apply_nms: 
Option, + pub find_contours: bool, + pub yolo_preds_format: Option, + pub classes_excluded: Vec, + pub classes_retained: Vec, + pub min_width: Option, + pub min_height: Option, + pub db_unclip_ratio: Option, + pub db_binary_thresh: Option, + pub sam_kind: Option, + pub sam_low_res_mask: Option, +} + +impl Default for ModelConfig { + fn default() -> Self { + Self { + class_names: None, + keypoint_names: None, + text_names: None, + class_confs: vec![0.25f32], + keypoint_confs: vec![0.3f32], + text_confs: vec![0.25f32], + apply_softmax: Some(false), + num_classes: None, + num_keypoints: None, + num_masks: None, + iou: None, + find_contours: false, + yolo_preds_format: None, + classes_excluded: vec![], + classes_retained: vec![], + apply_nms: None, + min_width: None, + min_height: None, + db_unclip_ratio: Some(1.5), + db_binary_thresh: Some(0.2), + sam_kind: None, + sam_low_res_mask: None, + topk: None, + model: Default::default(), + encoder: Default::default(), + decoder: Default::default(), + visual: Default::default(), + textual: Default::default(), + visual_encoder: Default::default(), + textual_encoder: Default::default(), + visual_decoder: Default::default(), + textual_decoder: Default::default(), + textual_decoder_merged: Default::default(), + processor: ProcessorConfig::default(), + size_encoder: Default::default(), + size_decoder: Default::default(), + coord_encoder: Default::default(), + coord_decoder: Default::default(), + visual_projection: Default::default(), + textual_projection: Default::default(), + version: None, + task: None, + scale: None, + name: Default::default(), + } + } +} + +impl ModelConfig { + pub fn exclude_classes(mut self, xs: &[usize]) -> Self { + self.classes_retained.clear(); + self.classes_excluded.extend_from_slice(xs); + self + } + + pub fn retain_classes(mut self, xs: &[usize]) -> Self { + self.classes_excluded.clear(); + self.classes_retained.extend_from_slice(xs); + self + } + + pub fn commit(mut self) -> anyhow::Result { + fn 
try_commit(name: &str, mut m: EngineConfig) -> anyhow::Result { + if !m.file.is_empty() { + m = m.try_commit(name)?; + return Ok(m); + } + Ok(m) + } + + self.model = try_commit(self.name, self.model)?; + self.visual = try_commit(self.name, self.visual)?; + self.textual = try_commit(self.name, self.textual)?; + self.encoder = try_commit(self.name, self.encoder)?; + self.decoder = try_commit(self.name, self.decoder)?; + self.visual_encoder = try_commit(self.name, self.visual_encoder)?; + self.textual_encoder = try_commit(self.name, self.textual_encoder)?; + self.visual_decoder = try_commit(self.name, self.visual_decoder)?; + self.textual_decoder = try_commit(self.name, self.textual_decoder)?; + self.textual_decoder_merged = try_commit(self.name, self.textual_decoder_merged)?; + self.size_encoder = try_commit(self.name, self.size_encoder)?; + self.size_decoder = try_commit(self.name, self.size_decoder)?; + self.coord_encoder = try_commit(self.name, self.coord_encoder)?; + self.coord_decoder = try_commit(self.name, self.coord_decoder)?; + self.visual_projection = try_commit(self.name, self.visual_projection)?; + self.textual_projection = try_commit(self.name, self.textual_projection)?; + + Ok(self) + } + + pub fn with_batch_size_all(mut self, batch_size: usize) -> Self { + self.visual = self.visual.with_ixx(0, 0, batch_size.into()); + self.textual = self.textual.with_ixx(0, 0, batch_size.into()); + self.model = self.model.with_ixx(0, 0, batch_size.into()); + self.encoder = self.encoder.with_ixx(0, 0, batch_size.into()); + self.decoder = self.decoder.with_ixx(0, 0, batch_size.into()); + self.visual_encoder = self.visual_encoder.with_ixx(0, 0, batch_size.into()); + self.textual_encoder = self.textual_encoder.with_ixx(0, 0, batch_size.into()); + self.visual_decoder = self.visual_decoder.with_ixx(0, 0, batch_size.into()); + self.textual_decoder = self.textual_decoder.with_ixx(0, 0, batch_size.into()); + self.textual_decoder_merged = self + .textual_decoder_merged + 
.with_ixx(0, 0, batch_size.into()); + self.size_encoder = self.size_encoder.with_ixx(0, 0, batch_size.into()); + self.size_decoder = self.size_decoder.with_ixx(0, 0, batch_size.into()); + self.coord_encoder = self.coord_encoder.with_ixx(0, 0, batch_size.into()); + self.coord_decoder = self.coord_decoder.with_ixx(0, 0, batch_size.into()); + self.visual_projection = self.visual_projection.with_ixx(0, 0, batch_size.into()); + self.textual_projection = self.textual_projection.with_ixx(0, 0, batch_size.into()); + + self + } + + pub fn with_device_all(mut self, device: crate::Device) -> Self { + self.visual = self.visual.with_device(device); + self.textual = self.textual.with_device(device); + self.model = self.model.with_device(device); + self.encoder = self.encoder.with_device(device); + self.decoder = self.decoder.with_device(device); + self.visual_encoder = self.visual_encoder.with_device(device); + self.textual_encoder = self.textual_encoder.with_device(device); + self.visual_decoder = self.visual_decoder.with_device(device); + self.textual_decoder = self.textual_decoder.with_device(device); + self.textual_decoder_merged = self.textual_decoder_merged.with_device(device); + self.size_encoder = self.size_encoder.with_device(device); + self.size_decoder = self.size_decoder.with_device(device); + self.coord_encoder = self.coord_encoder.with_device(device); + self.coord_decoder = self.coord_decoder.with_device(device); + self.visual_projection = self.visual_projection.with_device(device); + self.textual_projection = self.textual_projection.with_device(device); + + self + } + + pub fn with_dtype_all(mut self, dtype: crate::DType) -> Self { + self.visual = self.visual.with_dtype(dtype); + self.textual = self.textual.with_dtype(dtype); + self.model = self.model.with_dtype(dtype); + self.encoder = self.encoder.with_dtype(dtype); + self.decoder = self.decoder.with_dtype(dtype); + self.visual_encoder = self.visual_encoder.with_dtype(dtype); + self.textual_encoder = 
self.textual_encoder.with_dtype(dtype); + self.visual_decoder = self.visual_decoder.with_dtype(dtype); + self.textual_decoder = self.textual_decoder.with_dtype(dtype); + self.textual_decoder_merged = self.textual_decoder_merged.with_dtype(dtype); + self.size_encoder = self.size_encoder.with_dtype(dtype); + self.size_decoder = self.size_decoder.with_dtype(dtype); + self.coord_encoder = self.coord_encoder.with_dtype(dtype); + self.coord_decoder = self.coord_decoder.with_dtype(dtype); + self.visual_projection = self.visual_projection.with_dtype(dtype); + self.textual_projection = self.textual_projection.with_dtype(dtype); + + self + } +} + +impl_model_config_methods!(ModelConfig, model); +impl_model_config_methods!(ModelConfig, visual); +impl_model_config_methods!(ModelConfig, textual); +impl_model_config_methods!(ModelConfig, encoder); +impl_model_config_methods!(ModelConfig, decoder); +impl_model_config_methods!(ModelConfig, visual_encoder); +impl_model_config_methods!(ModelConfig, textual_encoder); +impl_model_config_methods!(ModelConfig, visual_decoder); +impl_model_config_methods!(ModelConfig, textual_decoder); +impl_model_config_methods!(ModelConfig, textual_decoder_merged); +impl_model_config_methods!(ModelConfig, size_encoder); +impl_model_config_methods!(ModelConfig, size_decoder); +impl_model_config_methods!(ModelConfig, coord_encoder); +impl_model_config_methods!(ModelConfig, coord_decoder); +impl_model_config_methods!(ModelConfig, visual_projection); +impl_model_config_methods!(ModelConfig, textual_projection); +impl_process_config_methods!(ModelConfig, processor); diff --git a/src/io/dataloader.rs b/src/io/dataloader.rs index 420fdff..19aed4e 100644 --- a/src/io/dataloader.rs +++ b/src/io/dataloader.rs @@ -367,14 +367,14 @@ impl DataLoader { fn load_image_paths_from_folder(source: &str, exts: &[&str]) -> Result> { let source_path = Path::new(source); let mut paths: Vec = Vec::new(); - let options = MatchOptions { + let config = MatchOptions { 
case_sensitive: false, require_literal_separator: false, require_literal_leading_dot: false, }; for ext in exts.iter() { let pattern = source_path.join(format!("*.{}", ext)); - let paths_: Vec = glob_with(pattern.to_str().unwrap(), options)? + let paths_: Vec = glob_with(pattern.to_str().unwrap(), config)? .filter_map(|entry| entry.ok()) .collect(); paths.extend(paths_); @@ -393,12 +393,12 @@ impl DataLoader { } fn glob(pattern: &str, sort: bool, case_sensitive: bool) -> anyhow::Result> { - let options = MatchOptions { + let config = MatchOptions { case_sensitive, require_literal_separator: false, require_literal_leading_dot: false, }; - let mut paths: Vec = glob_with(pattern, options)? + let mut paths: Vec = glob_with(pattern, config)? .filter_map(|entry| entry.ok()) .collect(); @@ -479,7 +479,7 @@ impl DataLoader { self } - pub fn with_batch_size(mut self, x: usize) -> Self { + pub fn with_batch_size_all(mut self, x: usize) -> Self { self.batch_size = x; self } diff --git a/src/models/beit/config.rs b/src/models/beit/config.rs index 46636b6..34eb389 100644 --- a/src/models/beit/config.rs +++ b/src/models/beit/config.rs @@ -1,10 +1,8 @@ -use crate::NAMES_IMAGENET_1K; - /// Model configuration for `BEiT` -impl crate::Options { +impl crate::ModelConfig { pub fn beit() -> Self { Self::default() - .with_model_name("beit") + .with_name("beit") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 224.into()) @@ -13,7 +11,7 @@ impl crate::Options { .with_image_std(&[0.5, 0.5, 0.5]) .with_normalize(true) .with_apply_softmax(true) - .with_class_names(&NAMES_IMAGENET_1K) + .with_class_names(&crate::NAMES_IMAGENET_1K) } pub fn beit_base() -> Self { diff --git a/src/models/ben2/config.rs b/src/models/ben2/config.rs index b6b5632..f942025 100644 --- a/src/models/ben2/config.rs +++ b/src/models/ben2/config.rs @@ -1,5 +1,5 @@ /// Model configuration for `BEN2` -impl crate::Options { +impl crate::ModelConfig { pub fn ben2_base() -> Self { 
Self::rmbg().with_model_file("ben2-base.onnx") } diff --git a/src/models/blip/config.rs b/src/models/blip/config.rs index 038e046..0395d54 100644 --- a/src/models/blip/config.rs +++ b/src/models/blip/config.rs @@ -1,34 +1,24 @@ /// Model configuration for `BLIP` -impl crate::Options { - pub fn blip() -> Self { - Self::default().with_model_name("blip").with_batch_size(1) - } - +impl crate::ModelConfig { #[allow(clippy::excessive_precision)] - pub fn blip_visual() -> Self { - Self::blip() - .with_model_kind(crate::Kind::Vision) - .with_model_ixx(0, 2, 384.into()) - .with_model_ixx(0, 3, 384.into()) + pub fn blip() -> Self { + Self::default() + .with_name("blip") + .with_batch_size_all(1) + .with_visual_ixx(0, 1, 3.into()) + .with_visual_ixx(0, 2, 384.into()) + .with_visual_ixx(0, 3, 384.into()) .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) .with_image_std(&[0.26862954, 0.26130258, 0.27577711]) - .with_resize_filter("Bilinear") - .with_normalize(true) } - pub fn blip_textual() -> Self { - Self::blip().with_model_kind(crate::Kind::Language) - } - - pub fn blip_v1_base_caption_visual() -> Self { - Self::blip_visual() - .with_model_version(1.into()) - .with_model_file("v1-base-caption-visual.onnx") - } - - pub fn blip_v1_base_caption_textual() -> Self { - Self::blip_textual() - .with_model_version(1.into()) - .with_model_file("v1-base-caption-textual.onnx") + pub fn blip_v1_base_caption() -> Self { + Self::blip() + .with_version(1.into()) + .with_visual_file("v1-base-caption-visual.onnx") + .with_textual_file("v1-base-caption-textual.onnx") + .with_tokenizer_file("blip/tokenizer.json") + .with_tokenizer_config_file("blip/tokenizer_config.json") + .with_special_tokens_map_file("blip/special_tokens_map.json") } } diff --git a/src/models/blip/impl.rs b/src/models/blip/impl.rs index 3719f8a..3a4e922 100644 --- a/src/models/blip/impl.rs +++ b/src/models/blip/impl.rs @@ -2,26 +2,34 @@ use aksr::Builder; use anyhow::Result; use ndarray::{s, Axis}; -use crate::{ - 
elapsed, - models::{BaseModelTextual, BaseModelVisual}, - Image, LogitsSampler, Options, Ts, Xs, X, Y, -}; +use crate::{elapsed, Engine, Image, LogitsSampler, ModelConfig, Processor, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct Blip { - visual: BaseModelVisual, - textual: BaseModelTextual, - ts: Ts, + visual: Engine, + textual: Engine, + batch: usize, + height: usize, + width: usize, + processor: Processor, max_length: usize, eos_token_id: u32, + ts: Ts, } impl Blip { - pub fn new(options_visual: Options, options_textual: Options) -> Result { - let visual = BaseModelVisual::new(options_visual)?; - let textual = BaseModelTextual::new(options_textual)?; - let ts = Ts::merge(&[visual.engine().ts(), textual.engine().ts()]); + pub fn new(config: ModelConfig) -> Result { + let visual = Engine::try_from_config(&config.visual)?; + let textual = Engine::try_from_config(&config.textual)?; + let (batch, height, width) = ( + visual.batch().opt(), + visual.try_height().unwrap_or(&384.into()).opt(), + visual.try_width().unwrap_or(&384.into()).opt(), + ); + let ts = Ts::merge(&[visual.ts(), textual.ts()]); + let processor = Processor::try_from_config(&config.processor)? 
+ .with_image_width(width as _) + .with_image_height(height as _); let max_length = 512; let eos_token_id = 102; @@ -31,17 +39,24 @@ impl Blip { ts, max_length, eos_token_id, + batch, + height, + width, + processor, }) } pub fn encode_images(&mut self, xs: &[Image]) -> Result { - self.visual.encode(xs) + let ys = self.processor.process_images(xs)?; + self.batch = xs.len(); // update + let ys = self.visual.run(ys.into())?; + + Ok(ys[0].to_owned()) } pub fn encode_texts(&mut self, text: Option<&str>) -> Result>> { let input_ids = self - .textual - .processor() + .processor .encode_text_ids(text.unwrap_or_default(), false)?; Ok(vec![input_ids.clone(); self.batch()]) } @@ -70,11 +85,11 @@ impl Blip { let input_ids_attn_mask = X::ones(input_ids_nd.dims()); // decode - let outputs = self.textual.inference(Xs::from(vec![ + let outputs = self.textual.run(Xs::from(vec![ input_ids_nd, input_ids_attn_mask, image_embeds.clone(), - X::ones(&[self.visual().batch(), image_embeds.dims()[1]]), // image_embeds_attn_mask + X::ones(&[self.batch(), image_embeds.dims()[1]]), ]))?; // decode each token for each batch @@ -102,7 +117,7 @@ impl Blip { } // batch decode - let texts = self.textual.processor().decode_tokens_batch( + let texts = self.processor.decode_tokens_batch( &token_ids .into_iter() .map(|v| v.into_iter().map(|x| x as u32).collect::>()) @@ -114,7 +129,6 @@ impl Blip { .into_iter() .map(|x| Y::default().with_texts(&[&x])) .collect::>(); - // .into(); Ok(ys) } @@ -122,8 +136,4 @@ impl Blip { pub fn summary(&mut self) { self.ts.summary(); } - - pub fn batch(&self) -> usize { - self.visual.batch() as _ - } } diff --git a/src/models/clip/config.rs b/src/models/clip/config.rs index 0454261..d0712b2 100644 --- a/src/models/clip/config.rs +++ b/src/models/clip/config.rs @@ -1,71 +1,57 @@ -use crate::Kind; - /// Model configuration for `CLIP` -impl crate::Options { +impl crate::ModelConfig { pub fn clip() -> Self { Self::default() - .with_model_name("clip") - .with_model_ixx(0, 0, 
1.into()) - } - - pub fn clip_visual() -> Self { - Self::clip() - .with_model_kind(Kind::Vision) - .with_model_ixx(0, 2, 224.into()) - .with_model_ixx(0, 3, 224.into()) + .with_name("clip") + .with_batch_size_all(1) + .with_visual_ixx(0, 1, 3.into()) + .with_visual_ixx(0, 2, 224.into()) + .with_visual_ixx(0, 3, 224.into()) .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) .with_image_std(&[0.26862954, 0.2613026, 0.2757771]) - } - - pub fn clip_textual() -> Self { - Self::clip() - .with_model_kind(Kind::Language) .with_model_max_length(77) + .with_tokenizer_file("clip/tokenizer.json") + .with_tokenizer_config_file("clip/tokenizer_config.json") + .with_special_tokens_map_file("clip/special_tokens_map.json") + .with_config_file("clip/config.json") } - pub fn clip_vit_b16_visual() -> Self { - Self::clip_visual().with_model_file("vit-b16-visual.onnx") + pub fn clip_vit_b16() -> Self { + Self::clip() + .with_visual_file("vit-b16-visual.onnx") + .with_textual_file("vit-b16-textual.onnx") } - pub fn clip_vit_b16_textual() -> Self { - Self::clip_textual().with_model_file("vit-b16-textual.onnx") + pub fn clip_vit_b32() -> Self { + Self::clip() + .with_visual_file("vit-b32-visual.onnx") + .with_textual_file("vit-b32-textual.onnx") } - pub fn clip_vit_b32_visual() -> Self { - Self::clip_visual().with_model_file("vit-b32-visual.onnx") + pub fn clip_vit_l14() -> Self { + Self::clip() + .with_visual_file("vit-l14-visual.onnx") + .with_textual_file("vit-l14-textual.onnx") } - pub fn clip_vit_b32_textual() -> Self { - Self::clip_textual().with_model_file("vit-b32-textual.onnx") - } - - pub fn clip_vit_l14_visual() -> Self { - Self::clip_visual().with_model_file("vit-l14-visual.onnx") - } - - pub fn clip_vit_l14_textual() -> Self { - Self::clip_textual().with_model_file("vit-l14-textual.onnx") + pub fn jina_clip() -> Self { + Self::default() + .with_name("jina-clip-v1") + .with_batch_size_all(1) + .with_visual_ixx(0, 1, 3.into()) + .with_visual_ixx(0, 2, 224.into()) + 
.with_visual_ixx(0, 3, 224.into()) + .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) + .with_image_std(&[0.26862954, 0.2613026, 0.2757771]) + .with_tokenizer_file("jina-clip-v1/tokenizer.json") + .with_tokenizer_config_file("jina-clip-v1/tokenizer_config.json") + .with_special_tokens_map_file("jina-clip-v1/special_tokens_map.json") + .with_config_file("jina-clip-v1/config.json") } pub fn jina_clip_v1() -> Self { - Self::default() - .with_model_name("jina-clip-v1") - .with_model_ixx(0, 0, 1.into()) - } - - pub fn jina_clip_v1_visual() -> Self { - Self::jina_clip_v1() - .with_model_kind(Kind::Vision) - .with_model_ixx(0, 2, 224.into()) - .with_model_ixx(0, 3, 224.into()) - .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) - .with_image_std(&[0.26862954, 0.2613026, 0.2757771]) - .with_model_file("visual.onnx") - } - - pub fn jina_clip_v1_textual() -> Self { - Self::jina_clip_v1() - .with_model_kind(Kind::Language) - .with_model_file("textual.onnx") + Self::jina_clip() + .with_visual_file("visual.onnx") + .with_textual_file("textual.onnx") } } diff --git a/src/models/clip/impl.rs b/src/models/clip/impl.rs index 34f41ec..e3231b9 100644 --- a/src/models/clip/impl.rs +++ b/src/models/clip/impl.rs @@ -2,11 +2,12 @@ use aksr::Builder; use anyhow::Result; use ndarray::Array2; -use crate::{elapsed, Engine, Image, Options, Processor, Ts, Xs, X}; +use crate::{elapsed, Engine, Image, ModelConfig, Processor, Ts, X}; #[derive(Debug, Builder)] -pub struct ClipVisual { - engine: Engine, +pub struct Clip { + visual: Engine, + textual: Engine, height: usize, width: usize, batch: usize, @@ -14,22 +15,23 @@ pub struct ClipVisual { ts: Ts, } -impl ClipVisual { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; - let (batch, height, width, ts) = ( - engine.batch().opt(), - engine.try_height().unwrap_or(&224.into()).opt(), - engine.try_width().unwrap_or(&224.into()).opt(), - engine.ts.clone(), +impl Clip { + pub fn new(config: ModelConfig) -> Result 
{ + let visual = Engine::try_from_config(&config.visual)?; + let textual = Engine::try_from_config(&config.textual)?; + let (batch, height, width) = ( + visual.batch().opt(), + visual.try_height().unwrap_or(&224.into()).opt(), + visual.try_width().unwrap_or(&224.into()).opt(), ); - let processor = options - .to_processor()? + let ts = Ts::merge(&[visual.ts(), textual.ts()]); + let processor = Processor::try_from_config(&config.processor)? .with_image_width(width as _) .with_image_height(height as _); Ok(Self { - engine, + textual, + visual, height, width, batch, @@ -38,111 +40,39 @@ impl ClipVisual { }) } - pub fn preprocess(&mut self, xs: &[Image]) -> Result { - let x = self.processor.process_images(xs)?; - - Ok(x.into()) - } - - pub fn inference(&mut self, xs: Xs) -> Result { - self.engine.run(xs) - } - pub fn encode_images(&mut self, xs: &[Image]) -> Result { - let xs = elapsed!("visual-preprocess", self.ts, { self.preprocess(xs)? }); - let xs = elapsed!("visual-inference", self.ts, { self.inference(xs)? }); + let xs = elapsed!("visual-preprocess", self.ts, { + self.processor.process_images(xs)? + }); + let xs = elapsed!("visual-inference", self.ts, { self.visual.run(xs.into())? }); let x = elapsed!("visual-postprocess", self.ts, { xs[0].to_owned() }); Ok(x) } -} - -#[derive(Debug, Builder)] -pub struct ClipTextual { - engine: Engine, - batch: usize, - processor: Processor, - ts: Ts, -} - -impl ClipTextual { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; - let (batch, ts) = (engine.batch().opt(), engine.ts.clone()); - let processor = options.to_processor()?; - - Ok(Self { - engine, - batch, - processor, - ts, - }) - } - - pub fn preprocess(&self, xs: &[&str]) -> Result { - let encodings: Vec = self - .processor - .encode_texts_ids(xs, false)? // skip_special_tokens - .into_iter() - .flatten() - .collect(); - - let x: X = Array2::from_shape_vec((xs.len(), encodings.len() / xs.len()), encodings)? 
- .into_dyn() - .into(); - - Ok(x.into()) - } - - pub fn inference(&mut self, xs: Xs) -> Result { - self.engine.run(xs) - } pub fn encode_texts(&mut self, xs: &[&str]) -> Result { - let xs = elapsed!("textual-preprocess", self.ts, { self.preprocess(xs)? }); - let xs = elapsed!("textual-inference", self.ts, { self.inference(xs)? }); + let xs = elapsed!("textual-preprocess", self.ts, { + let encodings: Vec = self + .processor + .encode_texts_ids(xs, false)? + .into_iter() + .flatten() + .collect(); + + let x: X = Array2::from_shape_vec((xs.len(), encodings.len() / xs.len()), encodings)? + .into_dyn() + .into(); + x + }); + let xs = elapsed!("textual-inference", self.ts, { + self.textual.run(xs.into())? + }); let x = elapsed!("textual-postprocess", self.ts, { xs[0].to_owned() }); Ok(x) } -} - -#[derive(Debug, Builder)] -pub struct Clip { - textual: ClipTextual, - visual: ClipVisual, - ts: Ts, -} - -impl Clip { - pub fn new(options_visual: Options, options_textual: Options) -> Result { - let visual = ClipVisual::new(options_visual)?; - let textual = ClipTextual::new(options_textual)?; - // let ts = Ts::merge(&[visual.engine().ts(), textual.engine().ts()]); - let ts = Ts::default(); - - Ok(Self { - textual, - visual, - ts, - }) - } - - pub fn encode_images(&mut self, xs: &[Image]) -> Result { - let x = elapsed!("encode_images", self.ts, { self.visual.encode_images(xs)? }); - Ok(x) - } - - pub fn encode_texts(&mut self, xs: &[&str]) -> Result { - let x = elapsed!("encode_texts", self.ts, { self.textual.encode_texts(xs)? 
}); - Ok(x) - } pub fn summary(&mut self) { - // self.ts.clear(); - // self.ts = Ts::merge(&[&self.ts, self.visual.ts(), self.textual.ts()]); self.ts.summary(); - self.visual.ts().summary(); - self.textual.ts().summary(); } } diff --git a/src/models/convnext/config.rs b/src/models/convnext/config.rs index 0ca435d..30d0074 100644 --- a/src/models/convnext/config.rs +++ b/src/models/convnext/config.rs @@ -1,10 +1,10 @@ use crate::NAMES_IMAGENET_1K; /// Model configuration for `ConvNeXt` -impl crate::Options { +impl crate::ModelConfig { pub fn convnext() -> Self { Self::default() - .with_model_name("convnext") + .with_name("convnext") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 224.into()) diff --git a/src/models/d_fine/config.rs b/src/models/d_fine/config.rs index ea2cea7..16de585 100644 --- a/src/models/d_fine/config.rs +++ b/src/models/d_fine/config.rs @@ -1,7 +1,7 @@ /// Model configuration for `d_fine` -impl crate::Options { +impl crate::ModelConfig { pub fn d_fine() -> Self { - Self::rtdetr().with_model_name("d-fine") + Self::rtdetr().with_name("d-fine") } pub fn d_fine_n_coco() -> Self { diff --git a/src/models/db/config.rs b/src/models/db/config.rs index c594526..0493237 100644 --- a/src/models/db/config.rs +++ b/src/models/db/config.rs @@ -1,8 +1,8 @@ /// Model configuration for [DB](https://github.com/MhLiao/DB) and [PaddleOCR-Det](https://github.com/PaddlePaddle/PaddleOCR) -impl crate::Options { +impl crate::ModelConfig { pub fn db() -> Self { Self::default() - .with_model_name("db") + .with_name("db") .with_model_ixx(0, 0, (1, 1, 8).into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, (608, 960, 1600).into()) @@ -11,7 +11,7 @@ impl crate::Options { .with_normalize(true) .with_image_mean(&[0.485, 0.456, 0.406]) .with_image_std(&[0.229, 0.224, 0.225]) - .with_binary_thresh(0.2) + .with_db_binary_thresh(0.2) .with_class_confs(&[0.35]) .with_min_width(5.0) .with_min_height(12.0) diff --git 
a/src/models/db/impl.rs b/src/models/db/impl.rs index 651fa87..af7d874 100644 --- a/src/models/db/impl.rs +++ b/src/models/db/impl.rs @@ -4,7 +4,8 @@ use ndarray::Axis; use rayon::prelude::*; use crate::{ - elapsed, DynConf, Engine, Hbb, Image, Mask, Obb, Ops, Options, Polygon, Processor, Ts, Xs, Y, + elapsed, DynConf, Engine, Hbb, Image, Mask, ModelConfig, Obb, Ops, Polygon, Processor, Ts, Xs, + Y, }; #[derive(Debug, Builder)] @@ -24,8 +25,8 @@ pub struct DB { } impl DB { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts, spec) = ( engine.batch().opt(), engine.try_height().unwrap_or(&960.into()).opt(), @@ -33,15 +34,14 @@ impl DB { engine.ts.clone(), engine.spec().to_owned(), ); - let processor = options - .to_processor()? + let processor = Processor::try_from_config(&config.processor)? .with_image_width(width as _) .with_image_height(height as _); - let confs = DynConf::new(options.class_confs(), 1); - let binary_thresh = options.binary_thresh().unwrap_or(0.2); - let unclip_ratio = options.unclip_ratio().unwrap_or(1.5); - let min_width = options.min_width().unwrap_or(12.0); - let min_height = options.min_height().unwrap_or(5.0); + let confs = DynConf::new(config.class_confs(), 1); + let binary_thresh = config.db_binary_thresh().unwrap_or(0.2); + let unclip_ratio = config.db_unclip_ratio().unwrap_or(1.5); + let min_width = config.min_width().unwrap_or(12.0); + let min_height = config.min_height().unwrap_or(5.0); Ok(Self { engine, diff --git a/src/models/deim/config.rs b/src/models/deim/config.rs index 10c4a0a..81177ae 100644 --- a/src/models/deim/config.rs +++ b/src/models/deim/config.rs @@ -1,7 +1,7 @@ /// Model configuration for `DEIM` -impl crate::Options { +impl crate::ModelConfig { pub fn deim() -> Self { - Self::d_fine().with_model_name("deim") + Self::d_fine().with_name("deim") } pub fn 
deim_dfine_s_coco() -> Self { diff --git a/src/models/deit/config.rs b/src/models/deit/config.rs index 8be8000..1f7c3b9 100644 --- a/src/models/deit/config.rs +++ b/src/models/deit/config.rs @@ -1,10 +1,10 @@ use crate::NAMES_IMAGENET_1K; /// Model configuration for `DeiT` -impl crate::Options { +impl crate::ModelConfig { pub fn deit() -> Self { Self::default() - .with_model_name("deit") + .with_name("deit") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 224.into()) diff --git a/src/models/depth_anything/config.rs b/src/models/depth_anything/config.rs index d19344d..e8adefd 100644 --- a/src/models/depth_anything/config.rs +++ b/src/models/depth_anything/config.rs @@ -1,8 +1,8 @@ /// Model configuration for `DepthAnything` -impl crate::Options { +impl crate::ModelConfig { pub fn depth_anything() -> Self { Self::default() - .with_model_name("depth-anything") + .with_name("depth-anything") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, (384, 518, 1024).into()) @@ -14,26 +14,26 @@ impl crate::Options { } pub fn depth_anything_s() -> Self { - Self::depth_anything().with_model_scale(crate::Scale::S) + Self::depth_anything().with_scale(crate::Scale::S) } pub fn depth_anything_v1() -> Self { - Self::depth_anything().with_model_version(1.into()) + Self::depth_anything().with_version(1.into()) } pub fn depth_anything_v2() -> Self { - Self::depth_anything().with_model_version(2.into()) + Self::depth_anything().with_version(2.into()) } pub fn depth_anything_v1_small() -> Self { Self::depth_anything_v1() - .with_model_scale(crate::Scale::S) + .with_scale(crate::Scale::S) .with_model_file("v1-s.onnx") } pub fn depth_anything_v2_small() -> Self { Self::depth_anything_v2() - .with_model_scale(crate::Scale::S) + .with_scale(crate::Scale::S) .with_model_file("v2-s.onnx") } } diff --git a/src/models/depth_anything/impl.rs b/src/models/depth_anything/impl.rs index f094cc6..778fcc8 100644 --- 
a/src/models/depth_anything/impl.rs +++ b/src/models/depth_anything/impl.rs @@ -1,7 +1,7 @@ use aksr::Builder; use anyhow::Result; -use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y}; +use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y}; #[derive(Debug, Builder)] pub struct DepthAnything { @@ -15,8 +15,8 @@ pub struct DepthAnything { } impl DepthAnything { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( @@ -25,9 +25,7 @@ impl DepthAnything { engine.try_width().unwrap_or(&518.into()).opt(), engine.ts().clone(), ); - - let processor = options - .to_processor()? + let processor = Processor::try_from_config(&config.processor)? .with_image_width(width as _) .with_image_height(height as _); diff --git a/src/models/depth_pro/config.rs b/src/models/depth_pro/config.rs index 451682e..5cbddb1 100644 --- a/src/models/depth_pro/config.rs +++ b/src/models/depth_pro/config.rs @@ -1,8 +1,8 @@ /// Model configuration for `DepthPro` -impl crate::Options { +impl crate::ModelConfig { pub fn depth_pro() -> Self { Self::default() - .with_model_name("depth-pro") + .with_name("depth-pro") .with_model_ixx(0, 0, 1.into()) // batch. 
Note: now only support batch_size = 1 .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 1536.into()) @@ -12,16 +12,4 @@ impl crate::Options { .with_resize_mode(crate::ResizeMode::FitExact) .with_normalize(true) } - - // pub fn depth_pro_q4f16() -> Self { - // Self::depth_pro().with_model_file("q4f16.onnx") - // } - - // pub fn depth_pro_fp16() -> Self { - // Self::depth_pro().with_model_file("fp16.onnx") - // } - - // pub fn depth_pro_bnb4() -> Self { - // Self::depth_pro().with_model_file("bnb4.onnx") - // } } diff --git a/src/models/depth_pro/impl.rs b/src/models/depth_pro/impl.rs index 6e1d254..6301437 100644 --- a/src/models/depth_pro/impl.rs +++ b/src/models/depth_pro/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::Axis; -use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y}; +use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct DepthPro { @@ -16,8 +16,8 @@ pub struct DepthPro { } impl DepthPro { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( engine.batch().opt(), @@ -25,8 +25,7 @@ impl DepthPro { engine.try_width().unwrap_or(&512.into()).opt(), engine.ts().clone(), ); - let processor = options - .to_processor()? + let processor = Processor::try_from_config(&config.processor)? 
.with_image_width(width as _) .with_image_height(height as _); diff --git a/src/models/dinov2/config.rs b/src/models/dinov2/config.rs index abf7696..60df927 100644 --- a/src/models/dinov2/config.rs +++ b/src/models/dinov2/config.rs @@ -1,8 +1,8 @@ /// Model configuration for `DINOv2` -impl crate::Options { +impl crate::ModelConfig { pub fn dinov2() -> Self { Self::default() - .with_model_name("dinov2") + .with_name("dinov2") .with_model_ixx(0, 0, (1, 1, 8).into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 224.into()) @@ -16,13 +16,13 @@ impl crate::Options { pub fn dinov2_small() -> Self { Self::dinov2() - .with_model_scale(crate::Scale::S) + .with_scale(crate::Scale::S) .with_model_file("s.onnx") } pub fn dinov2_base() -> Self { Self::dinov2() - .with_model_scale(crate::Scale::B) + .with_scale(crate::Scale::B) .with_model_file("b.onnx") } } diff --git a/src/models/dinov2/impl.rs b/src/models/dinov2/impl.rs index 813851d..c9a070c 100644 --- a/src/models/dinov2/impl.rs +++ b/src/models/dinov2/impl.rs @@ -1,7 +1,7 @@ use aksr::Builder; use anyhow::Result; -use crate::{elapsed, Engine, Image, Options, Processor, Scale, Ts, Xs, X}; +use crate::{elapsed, Engine, Image, ModelConfig, Processor, Scale, Ts, Xs, X}; #[derive(Builder, Debug)] pub struct DINOv2 { @@ -15,15 +15,15 @@ pub struct DINOv2 { } impl DINOv2 { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), engine.try_height().unwrap_or(&384.into()).opt(), engine.try_width().unwrap_or(&384.into()).opt(), engine.ts.clone(), ); - let dim = match options.model_scale() { + let dim = match &config.scale { Some(Scale::S) => 384, Some(Scale::B) => 768, Some(Scale::L) => 1024, @@ -31,8 +31,7 @@ impl DINOv2 { Some(x) => anyhow::bail!("Unsupported scale: {:?}", x), None => anyhow::bail!("No model scale specified"), }; - let 
processor = options - .to_processor()? + let processor = Processor::try_from_config(&config.processor)? .with_image_width(width as _) .with_image_height(height as _); diff --git a/src/models/fast/config.rs b/src/models/fast/config.rs index 6ec39c4..02c4250 100644 --- a/src/models/fast/config.rs +++ b/src/models/fast/config.rs @@ -1,8 +1,8 @@ /// Model configuration for [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://github.com/czczup/FAST) -impl crate::Options { +impl crate::ModelConfig { pub fn fast() -> Self { Self::db() - .with_model_name("fast") + .with_name("fast") .with_image_mean(&[0.798, 0.785, 0.772]) .with_image_std(&[0.264, 0.2749, 0.287]) } diff --git a/src/models/fastvit/config.rs b/src/models/fastvit/config.rs index 71cdd31..10038ef 100644 --- a/src/models/fastvit/config.rs +++ b/src/models/fastvit/config.rs @@ -1,10 +1,10 @@ use crate::NAMES_IMAGENET_1K; /// Model configuration for `FastViT` -impl crate::Options { +impl crate::ModelConfig { pub fn fastvit() -> Self { Self::default() - .with_model_name("fastvit") + .with_name("fastvit") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 224.into()) diff --git a/src/models/florence2/config.rs b/src/models/florence2/config.rs index 8ef74ac..e1fac46 100644 --- a/src/models/florence2/config.rs +++ b/src/models/florence2/config.rs @@ -1,59 +1,31 @@ /// Model configuration for `Florence2` -impl crate::Options { +impl crate::ModelConfig { pub fn florence2() -> Self { Self::default() - .with_model_name("florence2") - .with_batch_size(1) - } - - pub fn florence2_visual() -> Self { - Self::florence2() - .with_model_kind(crate::Kind::Vision) - .with_model_ixx(0, 2, 768.into()) - .with_model_ixx(0, 3, 768.into()) + .with_name("florence2") + .with_batch_size_all(1) + .with_visual_ixx(0, 1, 3.into()) + .with_visual_ixx(0, 2, 768.into()) + .with_visual_ixx(0, 3, 768.into()) .with_image_mean(&[0.485, 0.456, 0.406]) 
.with_image_std(&[0.229, 0.224, 0.225]) - .with_resize_filter("Bilinear") - .with_normalize(true) } - pub fn florence2_textual() -> Self { - Self::florence2().with_model_kind(crate::Kind::Language) + pub fn florence2_base() -> Self { + Self::florence2() + .with_scale(crate::Scale::B) + .with_visual_file("base-vision-encoder.onnx") + .with_textual_file("base-embed-tokens.onnx") + .with_textual_encoder_file("base-encoder.onnx") + .with_textual_decoder_file("base-decoder.onnx") + .with_textual_decoder_merged_file("base-decoder-merged.onnx") + .with_tokenizer_file("florence2/tokenizer.json") + .with_config_file("florence2/config.json") + .with_special_tokens_map_file("florence2/special_tokens_map.json") + .with_tokenizer_config_file("florence2/tokenizer_config.json") } - pub fn florence2_visual_base() -> Self { - Self::florence2_visual().with_model_scale(crate::Scale::B) - } - - pub fn florence2_textual_base() -> Self { - Self::florence2_textual().with_model_scale(crate::Scale::B) - } - - pub fn florence2_visual_large() -> Self { - Self::florence2_visual().with_model_scale(crate::Scale::L) - } - - pub fn florence2_textual_large() -> Self { - Self::florence2_textual().with_model_scale(crate::Scale::L) - } - - pub fn florence2_visual_encoder_base() -> Self { - Self::florence2_visual_base().with_model_file("base-vision-encoder.onnx") - } - - pub fn florence2_textual_embed_base() -> Self { - Self::florence2_textual_base().with_model_file("base-embed-tokens.onnx") - } - - pub fn florence2_texual_encoder_base() -> Self { - Self::florence2_textual_base().with_model_file("base-encoder.onnx") - } - - pub fn florence2_texual_decoder_base() -> Self { - Self::florence2_textual_base().with_model_file("base-decoder.onnx") - } - - pub fn florence2_texual_decoder_merged_base() -> Self { - Self::florence2_textual_base().with_model_file("base-decoder-merged.onnx") + pub fn florence2_large() -> Self { + todo!() } } diff --git a/src/models/florence2/impl.rs b/src/models/florence2/impl.rs 
index 6a56775..b1f68b0 100644 --- a/src/models/florence2/impl.rs +++ b/src/models/florence2/impl.rs @@ -4,51 +4,59 @@ use ndarray::{s, Axis}; use rayon::prelude::*; use crate::{ - elapsed, - models::{BaseModelTextual, BaseModelVisual, Quantizer}, - Hbb, Image, LogitsSampler, Options, Polygon, Scale, Task, Ts, Xs, X, Y, + elapsed, models::Quantizer, Engine, Hbb, Image, LogitsSampler, ModelConfig, Polygon, Processor, + Scale, Task, Ts, Xs, X, Y, }; #[derive(Debug, Builder)] pub struct Florence2 { - pub vision_encoder: BaseModelVisual, - pub text_embed: BaseModelTextual, - pub encoder: BaseModelTextual, - pub decoder: BaseModelTextual, - pub decoder_merged: BaseModelTextual, + pub vision_encoder: Engine, + pub text_embed: Engine, + pub encoder: Engine, + pub decoder: Engine, + pub decoder_merged: Engine, ts: Ts, quantizer: Quantizer, max_length: usize, eos_token_id: u32, decoder_start_token_id: u32, n_kvs: usize, + height: usize, + width: usize, + batch: usize, + processor: Processor, } impl Florence2 { - pub fn new( - options_vision_encoder: Options, - options_text_embed: Options, - options_encoder: Options, - options_decoder: Options, - options_decoder_merged: Options, - ) -> Result { - let vision_encoder = BaseModelVisual::new(options_vision_encoder)?; - let text_embed = BaseModelTextual::new(options_text_embed)?; - let encoder = BaseModelTextual::new(options_encoder)?; - let decoder = BaseModelTextual::new(options_decoder)?; - let decoder_merged = BaseModelTextual::new(options_decoder_merged)?; + pub fn new(config: ModelConfig) -> Result { + let vision_encoder = Engine::try_from_config(&config.visual)?; + let text_embed = Engine::try_from_config(&config.textual)?; + let encoder = Engine::try_from_config(&config.textual_encoder)?; + let decoder = Engine::try_from_config(&config.textual_decoder)?; + let decoder_merged = Engine::try_from_config(&config.textual_decoder_merged)?; + + let (batch, height, width) = ( + vision_encoder.batch().opt(), + 
vision_encoder.try_height().unwrap_or(&1024.into()).opt(), + vision_encoder.try_width().unwrap_or(&1024.into()).opt(), + ); + + let processor = Processor::try_from_config(&config.processor)? + .with_image_width(width as _) + .with_image_height(height as _); + let quantizer = Quantizer::default(); let ts = Ts::merge(&[ - vision_encoder.engine().ts(), - text_embed.engine().ts(), - encoder.engine().ts(), - decoder.engine().ts(), - decoder_merged.engine().ts(), + vision_encoder.ts(), + text_embed.ts(), + encoder.ts(), + decoder.ts(), + decoder_merged.ts(), ]); let max_length = 1024; let eos_token_id = 2; let decoder_start_token_id = 2; - let n_kvs = match decoder.scale() { + let n_kvs = match config.scale { Some(Scale::B) => 6, Some(Scale::L) => 12, _ => unimplemented!(), @@ -66,6 +74,10 @@ impl Florence2 { eos_token_id, decoder_start_token_id, n_kvs, + batch, + height, + width, + processor, }) } @@ -97,12 +109,12 @@ impl Florence2 { .map(|im| { let text = Self::process_task(task, im.height() as _, im.width() as _) .prompt_for_florence2()?; - let ids = self.text_embed.processor().encode_text_ids(&text, true)?; + let ids = self.processor.encode_text_ids(&text, true)?; X::from(ids).insert_axis(0) }) .collect::, _>>()?; let x = X::concat(&xs, 0)?; - let xs = self.text_embed.inference(x.into())?; + let xs = self.text_embed.run(x.into())?; let x = xs[0].to_owned(); Ok(x) @@ -110,7 +122,10 @@ impl Florence2 { pub fn forward(&mut self, xs_visual: &[Image], x_textual: &Task) -> Result> { let visual_embeddings = elapsed!("visual-encode", self.ts, { - self.vision_encoder.encode(xs_visual)? 
+ let xs = self.processor.process_images(xs_visual)?; + self.batch = xs_visual.len(); // update + let xs = self.vision_encoder.run(xs.into())?; + xs[0].to_owned() }); let textual_embedding = elapsed!("textual-encode", self.ts, { @@ -141,7 +156,7 @@ impl Florence2 { let attention_mask = X::ones(&[self.batch(), inputs_embeds.dims()[1]]); // encoder - let last_hidden_state = self.encoder.inference(Xs::from(vec![ + let last_hidden_state = self.encoder.run(Xs::from(vec![ attention_mask.clone(), inputs_embeds.clone(), ]))?[0] @@ -150,7 +165,7 @@ impl Florence2 { // decoder let inputs_embeds = inputs_embeds.slice(s![.., -1.., ..]); let inputs_embeds = X::from(inputs_embeds.to_owned().into_dyn()); - let mut decoder_outputs = self.decoder.inference(Xs::from(vec![ + let mut decoder_outputs = self.decoder.run(Xs::from(vec![ attention_mask.clone(), last_hidden_state.clone(), inputs_embeds, @@ -215,7 +230,7 @@ impl Florence2 { // decode let next_tokens = X::from(last_tokens.clone()).insert_axis(1)?; - let inputs_embeds = &self.text_embed.inference(Xs::from(next_tokens))?[0].clone(); + let inputs_embeds = &self.text_embed.run(Xs::from(next_tokens))?[0].clone(); let use_cache = X::ones(&[1]); let mut xs = vec![ attention_mask.clone(), @@ -229,13 +244,13 @@ impl Florence2 { xs.push(encoder_kvs[i * 2 + 1].clone()); } xs.push(use_cache); - decoder_outputs = self.decoder_merged.inference(xs.into())?; + decoder_outputs = self.decoder_merged.run(xs.into())?; } // batch decode let texts = self - .text_embed - .processor() + // .text_embed + .processor .decode_tokens_batch(&token_ids, false)?; Ok(texts) @@ -416,10 +431,6 @@ impl Florence2 { Ok(ys) } - pub fn batch(&self) -> usize { - self.vision_encoder.batch() as _ - } - pub fn summary(&mut self) { self.ts.summary(); } diff --git a/src/models/grounding_dino/config.rs b/src/models/grounding_dino/config.rs index ce7096d..1ff2b34 100644 --- a/src/models/grounding_dino/config.rs +++ b/src/models/grounding_dino/config.rs @@ -1,9 +1,8 @@ /// 
Model configuration for `GroundingDino` -impl crate::Options { +impl crate::ModelConfig { pub fn grounding_dino() -> Self { Self::default() - .with_model_name("grounding-dino") - .with_model_kind(crate::Kind::VisionLanguage) + .with_name("grounding-dino") .with_model_ixx(0, 0, 1.into()) // TODO: current onnx model does not support bs > 1 .with_model_ixx(0, 2, 800.into()) // TODO: matters .with_model_ixx(0, 3, 1200.into()) // TODO: matters @@ -11,9 +10,10 @@ impl crate::Options { .with_resize_filter("CatmullRom") .with_image_mean(&[0.485, 0.456, 0.406]) .with_image_std(&[0.229, 0.224, 0.225]) - .with_normalize(true) - .with_class_confs(&[0.25]) - .with_text_confs(&[0.25]) + .with_tokenizer_file("grounding-dino/tokenizer.json") + .with_config_file("grounding-dino/config.json") + .with_special_tokens_map_file("grounding-dino/special_tokens_map.json") + .with_tokenizer_config_file("grounding-dino/tokenizer_config.json") } pub fn grounding_dino_tiny() -> Self { diff --git a/src/models/grounding_dino/impl.rs b/src/models/grounding_dino/impl.rs index d03e948..d1082d9 100644 --- a/src/models/grounding_dino/impl.rs +++ b/src/models/grounding_dino/impl.rs @@ -4,7 +4,7 @@ use ndarray::{s, Array2, Axis}; use rayon::prelude::*; use std::fmt::Write; -use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, X, Y}; +use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y}; #[derive(Builder, Debug)] pub struct GroundingDINO { @@ -24,8 +24,8 @@ pub struct GroundingDINO { } impl GroundingDINO { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( engine.batch().opt(), @@ -33,11 +33,8 @@ impl GroundingDINO { engine.try_width().unwrap_or(&1200.into()).opt(), engine.ts().clone(), ); - let processor = options - .to_processor()? 
- .with_image_width(width as _) - .with_image_height(height as _); - let class_names = options + + let class_names = config .text_names .as_ref() .and_then(|v| { @@ -48,16 +45,20 @@ impl GroundingDINO { .collect(); (!v.is_empty()).then_some(v) }) - .ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the options. Ensure the 'text_names' field is non-empty and contains valid class names."))?; + .ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the config. Ensure the 'text_names' field is non-empty and contains valid class names."))?; let text_prompt = class_names.iter().fold(String::new(), |mut acc, text| { write!(&mut acc, "{}.", text).unwrap(); acc }); + + let confs_visual = DynConf::new(config.class_confs(), class_names.len()); + let confs_textual = DynConf::new(config.text_confs(), class_names.len()); + let processor = Processor::try_from_config(&config.processor)? + .with_image_width(width as _) + .with_image_height(height as _); let token_ids = processor.encode_text_ids(&text_prompt, true)?; let tokens = processor.encode_text_tokens(&text_prompt, true)?; let class_ids_map = Self::process_class_ids(&tokens); - let confs_visual = DynConf::new(options.class_confs(), class_names.len()); - let confs_textual = DynConf::new(options.text_confs(), class_names.len()); Ok(Self { engine, diff --git a/src/models/linknet/config.rs b/src/models/linknet/config.rs index c1c666a..9c952b1 100644 --- a/src/models/linknet/config.rs +++ b/src/models/linknet/config.rs @@ -1,8 +1,8 @@ /// Model configuration for [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/abs/1707.03718) -impl crate::Options { +impl crate::ModelConfig { pub fn linknet() -> Self { Self::fast() - .with_model_name("linknet") + .with_name("linknet") .with_image_mean(&[0.798, 0.785, 0.772]) .with_image_std(&[0.264, 0.2749, 0.287]) } diff --git a/src/models/mobileone/config.rs b/src/models/mobileone/config.rs index a76f271..1190d46 
100644 --- a/src/models/mobileone/config.rs +++ b/src/models/mobileone/config.rs @@ -1,10 +1,10 @@ use crate::NAMES_IMAGENET_1K; /// Model configuration for `MobileOne` -impl crate::Options { +impl crate::ModelConfig { pub fn mobileone() -> Self { Self::default() - .with_model_name("mobileone") + .with_name("mobileone") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 224.into()) diff --git a/src/models/modnet/config.rs b/src/models/modnet/config.rs index 05174d2..c72dd53 100644 --- a/src/models/modnet/config.rs +++ b/src/models/modnet/config.rs @@ -1,8 +1,8 @@ /// Model configuration for `MODNet` -impl crate::Options { +impl crate::ModelConfig { pub fn modnet() -> Self { Self::default() - .with_model_name("modnet") + .with_name("modnet") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 2, (416, 512, 800).into()) .with_model_ixx(0, 3, (416, 512, 800).into()) diff --git a/src/models/modnet/impl.rs b/src/models/modnet/impl.rs index 5322ab1..9dff8b3 100644 --- a/src/models/modnet/impl.rs +++ b/src/models/modnet/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::Axis; -use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y}; +use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct MODNet { @@ -16,8 +16,8 @@ pub struct MODNet { } impl MODNet { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( engine.batch().opt(), @@ -25,8 +25,7 @@ impl MODNet { engine.try_width().unwrap_or(&512.into()).opt(), engine.ts().clone(), ); - let processor = options - .to_processor()? + let processor = Processor::try_from_config(&config.processor)? 
.with_image_width(width as _) .with_image_height(height as _); diff --git a/src/models/moondream2/config.rs b/src/models/moondream2/config.rs index 96d0bf9..d5c7642 100644 --- a/src/models/moondream2/config.rs +++ b/src/models/moondream2/config.rs @@ -1,117 +1,47 @@ /// Model configuration for `moondream2` -impl crate::Options { +impl crate::ModelConfig { pub fn moondream2() -> Self { Self::default() - .with_model_name("moondream2") - .with_model_num_dry_run(0) + .with_name("moondream2") + .with_visual_encoder_ixx(0, 0, (1, 3, 4).into()) // patch count + .with_image_mean(&[0.5, 0.5, 0.5]) + .with_image_std(&[0.5, 0.5, 0.5]) + .with_resize_mode(crate::ResizeMode::FitExact) + .with_resize_filter("catmullrom") + .with_visual_projection_ixx(0, 0, 1.into()) + .with_textual_encoder_ixx(0, 0, 1.into()) + .with_textual_decoder_ixx(0, 0, 1.into()) + .with_size_encoder_ixx(0, 0, 1.into()) + .with_size_decoder_ixx(0, 0, 1.into()) + .with_coord_encoder_ixx(0, 0, 1.into()) + .with_coord_decoder_ixx(0, 0, 1.into()) + .with_tokenizer_file("moondream2/tokenizer.json") + .with_tokenizer_config_file("moondream2/tokenizer_config.json") } pub fn moondream2_0_5b() -> Self { - Self::moondream2().with_model_scale(crate::Scale::Billion(0.5)) + Self::moondream2() + .with_scale(crate::Scale::Billion(0.5)) + .with_visual_encoder_file("0.5b-vision-encoder.onnx") + .with_visual_projection_file("0.5b-vision-projection.onnx") + .with_textual_decoder_file("0.5b-text-decoder.onnx") + .with_textual_encoder_file("0.5b-text-encoder.onnx") + .with_coord_encoder_file("0.5b-coord-encoder.onnx") + .with_coord_decoder_file("0.5b-coord-decoder.onnx") + .with_size_encoder_file("0.5b-size-encoder.onnx") + .with_size_decoder_file("0.5b-size-decoder.onnx") } - pub fn moondream2_0_5b_vision_encoder() -> Self { - Self::moondream2_0_5b() - .with_model_ixx(0, 0, (1, 3, 4).into()) // patch count - .with_model_kind(crate::Kind::Vision) - .with_image_mean(&[0.5, 0.5, 0.5]) - .with_image_std(&[0.5, 0.5, 0.5]) - 
.with_normalize(true) - .with_resize_mode(crate::ResizeMode::FitExact) - .with_resize_filter("catmullrom") - .with_model_file("0.5b-vision-encoder.onnx") - } - - pub fn moondream2_0_5b_vision_projection() -> Self { - Self::moondream2_0_5b() - .with_batch_size(1) - .with_model_kind(crate::Kind::Vision) - .with_model_file("0.5b-vision-projection.onnx") - } - - pub fn moondream2_0_5b_text_decoder() -> Self { - Self::moondream2_0_5b() - .with_batch_size(1) - .with_model_kind(crate::Kind::Language) - .with_model_file("0.5b-text-decoder.onnx") - } - - pub fn moondream2_0_5b_text_encoder() -> Self { - Self::moondream2_0_5b() - .with_batch_size(1) - .with_model_kind(crate::Kind::Language) - .with_model_file("0.5b-text-encoder.onnx") - } - - pub fn moondream2_0_5b_coord_encoder() -> Self { - Self::moondream2_0_5b() - .with_batch_size(1) - .with_model_file("0.5b-coord-encoder.onnx") - } - - pub fn moondream2_0_5b_coord_decoder() -> Self { - Self::moondream2_0_5b() - .with_batch_size(1) - .with_model_file("0.5b-coord-decoder.onnx") - } - - pub fn moondream2_0_5b_size_encoder() -> Self { - Self::moondream2_0_5b() - .with_batch_size(1) - .with_model_file("0.5b-size-encoder.onnx") - } - - pub fn moondream2_0_5b_size_decoder() -> Self { - Self::moondream2_0_5b() - .with_batch_size(1) - .with_model_file("0.5b-size-decoder.onnx") - } - - pub fn moondream2_2b_vision_encoder() -> Self { - Self::moondream2_0_5b_vision_encoder() - .with_model_scale(crate::Scale::Billion(2.)) - .with_model_file("2b-vision-encoder.onnx") - } - - pub fn moondream2_2b_vision_projection() -> Self { - Self::moondream2_0_5b_vision_projection() - .with_model_scale(crate::Scale::Billion(2.)) - .with_model_file("2b-vision-projection.onnx") - } - - pub fn moondream2_2b_text_decoder() -> Self { - Self::moondream2_0_5b_text_decoder() - .with_model_scale(crate::Scale::Billion(2.)) - .with_model_file("2b-text-decoder.onnx") - } - - pub fn moondream2_2b_text_encoder() -> Self { - Self::moondream2_0_5b_text_encoder() - 
.with_model_scale(crate::Scale::Billion(2.)) - .with_model_file("2b-text-encoder.onnx") - } - - pub fn moondream2_2b_coord_encoder() -> Self { - Self::moondream2_0_5b_coord_encoder() - .with_model_scale(crate::Scale::Billion(2.)) - .with_model_file("2b-coord-encoder.onnx") - } - - pub fn moondream2_2b_coord_decoder() -> Self { - Self::moondream2_0_5b_coord_decoder() - .with_model_scale(crate::Scale::Billion(2.)) - .with_model_file("2b-coord-decoder.onnx") - } - - pub fn moondream2_2b_size_encoder() -> Self { - Self::moondream2_0_5b_size_encoder() - .with_model_scale(crate::Scale::Billion(2.)) - .with_model_file("2b-size-encoder.onnx") - } - - pub fn moondream2_2b_size_decoder() -> Self { - Self::moondream2_0_5b_size_decoder() - .with_model_scale(crate::Scale::Billion(2.)) - .with_model_file("2b-size-decoder.onnx") + pub fn moondream2_2b() -> Self { + Self::moondream2() + .with_scale(crate::Scale::Billion(2.)) + .with_visual_encoder_file("2b-vision-encoder.onnx") + .with_visual_projection_file("2b-vision-projection.onnx") + .with_textual_decoder_file("2b-text-decoder.onnx") + .with_textual_encoder_file("2b-text-encoder.onnx") + .with_coord_encoder_file("2b-coord-encoder.onnx") + .with_coord_decoder_file("2b-coord-decoder.onnx") + .with_size_encoder_file("2b-size-encoder.onnx") + .with_size_decoder_file("2b-size-decoder.onnx") } } diff --git a/src/models/moondream2/impl.rs b/src/models/moondream2/impl.rs index 55a0942..30bab95 100644 --- a/src/models/moondream2/impl.rs +++ b/src/models/moondream2/impl.rs @@ -5,66 +5,57 @@ use ndarray::{s, Array, Array2, Array3, Axis, IxDyn}; use ndarray_npy::ReadNpyExt; use crate::{ - BaseModelTextual, DType, Engine, Hbb, Hub, Image, Keypoint, LogitsSampler, Options, Processor, - Scale, Task, Ts, Xs, X, Y, + DType, Engine, Hbb, Hub, Image, Keypoint, LogitsSampler, ModelConfig, Processor, Scale, Task, + Xs, X, Y, }; #[derive(Builder, Debug)] pub struct Moondream2 { - vision_encoder: VisionEncoder, - vision_projection: 
VisionProjection, - pub text_decoder: BaseModelTextual, - text_encoder: BaseModelTextual, - coord_decoder: Option, - coord_encoder: Option, - size_decoder: Option, - size_encoder: Option, + vision_encoder: Engine, + vision_projection: Engine, + text_decoder: Engine, + text_encoder: Engine, + coord_decoder: Option, + coord_encoder: Option, + size_decoder: Option, + size_encoder: Option, initial_kv_cache: X, // TODO: use f16 scale: Scale, dtype: DType, max_length: usize, eos_token_id: u32, max_objects: usize, + num_patch: usize, + patch_size: usize, + processor: Processor, + seq_len: usize, } impl Moondream2 { - // TODO - #[allow(clippy::too_many_arguments)] - pub fn new( - options_vision_encoder: Options, - options_vision_projection: Options, - options_text_encoder: Options, - options_text_decoder: Options, - options_coord_encoder: Option, - options_coord_decoder: Option, - options_size_encoder: Option, - options_size_decoder: Option, - ) -> Result { + pub fn new(config: ModelConfig) -> Result { let max_length = 2048; let max_objects = 50; let eos_token_id = 50256; - let dtype = options_vision_encoder.model_dtype; - let scale = options_vision_encoder - .model_scale - .clone() - .unwrap_or(Scale::Billion(0.5)); + let dtype = config.visual_encoder.dtype; + let scale = config.scale.clone().unwrap_or(Scale::Billion(0.5)); let initial_kv_cache: X = KVCache::new(&scale, &dtype)?.0.into(); - let vision_encoder = VisionEncoder::new(options_vision_encoder)?; - let vision_projection = VisionProjection::new(options_vision_projection)?; - let text_decoder = BaseModelTextual::new(options_text_decoder)?; - let text_encoder = BaseModelTextual::new(options_text_encoder)?; - let coord_decoder = options_coord_decoder - .map(BaseModelTextual::new) - .transpose()?; - let coord_encoder = options_coord_encoder - .map(BaseModelTextual::new) - .transpose()?; - let size_decoder = options_size_decoder - .map(BaseModelTextual::new) - .transpose()?; - let size_encoder = options_size_encoder - 
.map(BaseModelTextual::new) - .transpose()?; + let vision_encoder = Engine::try_from_config(&config.visual_encoder)?; + let vision_projection = Engine::try_from_config(&config.visual_projection)?; + let text_decoder = Engine::try_from_config(&config.textual_decoder)?; + let text_encoder = Engine::try_from_config(&config.textual_encoder)?; + let coord_decoder = Engine::try_from_config(&config.coord_decoder).ok(); + let coord_encoder = Engine::try_from_config(&config.coord_encoder).ok(); + let size_decoder = Engine::try_from_config(&config.size_decoder).ok(); + let size_encoder = Engine::try_from_config(&config.size_encoder).ok(); + let (num_patch, patch_size, _ts) = ( + vision_encoder.batch().opt(), + vision_encoder.try_height().unwrap_or(&378.into()).opt(), + vision_encoder.ts.clone(), + ); + let seq_len = vision_projection.inputs_minoptmax[0][1].opt(); + let processor = Processor::try_from_config(&config.processor)? + .with_image_width(patch_size as _) + .with_image_height(patch_size as _); Ok(Self { vision_encoder, @@ -81,12 +72,16 @@ impl Moondream2 { eos_token_id, scale, dtype, + num_patch, + patch_size, + processor, + seq_len, }) } pub fn encode_image(&mut self, x: &Image) -> Result { - let patches_emb = self.vision_encoder.encode(x)?.clone().insert_axis(0)?; - let image_embedding = self.vision_projection.inference(patches_emb.into())?[0].to_owned(); + let patches_emb = self.encode(x)?.clone().insert_axis(0)?; + let image_embedding = self.vision_projection.run(patches_emb.into())?[0].to_owned(); Ok(image_embedding) } @@ -119,12 +114,7 @@ impl Moondream2 { Task::Vqa(query) => { let input_ids: Vec<_> = [198., 198., 24361., 25.] .iter() - .chain( - &self - .text_encoder - .processor() - .encode_text_ids(query, false)?, - ) + .chain(&self.processor.encode_text_ids(query, false)?) 
.chain(&[198., 198., 33706., 25.]) .cloned() .collect(); @@ -139,8 +129,7 @@ impl Moondream2 { .iter() .chain( &self - .text_encoder - .processor() + .processor .encode_text_ids(&format!(" {}", object), false)?, ) .chain(&[628.]) @@ -156,8 +145,7 @@ impl Moondream2 { .iter() .chain( &self - .text_encoder - .processor() + .processor .encode_text_ids(&format!(" {}", object), false)?, ) .chain(&[628.]) @@ -174,10 +162,10 @@ impl Moondream2 { fn generate_text(&mut self, input_ids: &[f32], kv_cache: Array) -> Result { let input_ids = X::from(input_ids.to_vec()).insert_axis(0)?; - let mut input_embeds = self.text_encoder.inference(Xs::from(input_ids))?[0].to_owned(); + let mut input_embeds = self.text_encoder.run(Xs::from(input_ids))?[0].to_owned(); let logits_sampler = LogitsSampler::new(); let mut token_ids: Vec = Vec::new(); - let mut pos = self.vision_projection.seq_len() + self.initial_kv_cache.shape()[4]; + let mut pos = self.seq_len + self.initial_kv_cache.shape()[4]; let mut inc = input_embeds.shape()[1]; let mut kv_cache = kv_cache.clone(); @@ -192,7 +180,7 @@ impl Moondream2 { .into_dyn() .into(), ]); - let decoder_outputs = self.text_decoder.inference(input)?; + let decoder_outputs = self.text_decoder.run(input)?; // update let logits = &decoder_outputs["logits"]; @@ -221,13 +209,10 @@ impl Moondream2 { // encode let next_tokens = X::from(vec![token_id as f32]).insert_axis(1)?; - input_embeds = self.text_encoder.inference(Xs::from(next_tokens))?[0].to_owned(); + input_embeds = self.text_encoder.run(Xs::from(next_tokens))?[0].to_owned(); } - let text = self - .text_encoder - .processor() - .decode_tokens(&token_ids, true)?; + let text = self.processor.decode_tokens(&token_ids, true)?; Ok(text) } @@ -242,16 +227,16 @@ impl Moondream2 { let mut y_bboxes: Vec = Vec::new(); let mut y_kpts: Vec> = Vec::new(); let (image_height, image_width) = ( - self.vision_encoder.processor.images_transform_info[0].height_src, - 
self.vision_encoder.processor.images_transform_info[0].width_src, + self.processor.images_transform_info[0].height_src, + self.processor.images_transform_info[0].width_src, ); - let mut pos = self.vision_projection.seq_len() + self.initial_kv_cache.shape()[4]; + let mut pos = self.seq_len + self.initial_kv_cache.shape()[4]; let logits_sampler = LogitsSampler::new(); // initial input_embeds let input_ids = X::from(input_ids.to_vec()).insert_axis(0)?; - let mut hidden = self.text_encoder.inference(Xs::from(input_ids))?[0].to_owned(); + let mut hidden = self.text_encoder.run(Xs::from(input_ids))?[0].to_owned(); let mut kv_cache = kv_cache; // generate @@ -273,12 +258,7 @@ impl Moondream2 { // cx let input: X = hidden.slice(s![0, -1, ..]).into_owned().into_dyn().into(); - let cx = self - .coord_decoder - .as_mut() - .unwrap() - .inference(Xs::from(input))?[0] - .clone(); // [1024] + let cx = self.coord_decoder.as_mut().unwrap().run(Xs::from(input))?[0].clone(); // [1024] let ratio = cx.shape()[0] as f32; let cx = logits_sampler .decode(cx.as_slice().context("Failed to get slice for `cx`")?)? @@ -288,7 +268,7 @@ impl Moondream2 { .coord_encoder .as_mut() .unwrap() - .inference(Xs::from(X::from(vec![cx])))?[0] + .run(Xs::from(X::from(vec![cx])))?[0] .clone() .insert_axis(0)? .insert_axis(0)?; @@ -296,12 +276,7 @@ impl Moondream2 { // cy let _logits = self.run_decoder(&mut hidden, &mut kv_cache, &mut pos)?; let input: X = hidden.slice(s![0, -1, ..]).into_owned().into_dyn().into(); - let cy = self - .coord_decoder - .as_mut() - .unwrap() - .inference(Xs::from(input))?[0] - .clone(); + let cy = self.coord_decoder.as_mut().unwrap().run(Xs::from(input))?[0].clone(); let ratio = cy.shape()[0] as f32; let cy = logits_sampler @@ -313,7 +288,7 @@ impl Moondream2 { .coord_encoder .as_mut() .unwrap() - .inference(Xs::from(X::from(vec![cy])))?[0] + .run(Xs::from(X::from(vec![cy])))?[0] .clone() .insert_axis(0)? 
.insert_axis(0)?; @@ -334,12 +309,7 @@ impl Moondream2 { // wh let _logits = self.run_decoder(&mut hidden, &mut kv_cache, &mut pos)?; let input: X = hidden.slice(s![0, -1, ..]).into_owned().into_dyn().into(); - let size = self - .size_decoder - .as_mut() - .unwrap() - .inference(Xs::from(input))?[0] - .clone(); // [2, 1024] + let size = self.size_decoder.as_mut().unwrap().run(Xs::from(input))?[0].clone(); // [2, 1024] let ratio = size.shape()[1] as f32; let w = logits_sampler.decode( @@ -361,7 +331,7 @@ impl Moondream2 { .size_encoder .as_mut() .unwrap() - .inference(Xs::from(X::from(vec![w, h])))?[0] + .run(Xs::from(X::from(vec![w, h])))?[0] .clone() .insert_axis(0)? .insert_axis(0)?; // [1024] @@ -392,7 +362,7 @@ impl Moondream2 { } fn prepare_kv_cache(&mut self, image_embedding: &X) -> Result> { - let kv_cache_new = self.text_decoder.inference(Xs::from(vec![ + let kv_cache_new = self.text_decoder.run(Xs::from(vec![ image_embedding.clone(), self.initial_kv_cache.clone(), ]))?["new_kv_cache"] @@ -421,7 +391,7 @@ impl Moondream2 { kv_cache: &mut Array, pos: &mut usize, ) -> Result { - let decoder_outputs = self.text_decoder.inference(Xs::from(vec![ + let decoder_outputs = self.text_decoder.run(Xs::from(vec![ input_embeds.clone(), kv_cache .slice(s![.., .., .., .., ..*pos, ..]) @@ -442,38 +412,6 @@ impl Moondream2 { Ok(decoder_outputs["logits"].to_owned()) } -} - -#[derive(Debug, Builder)] -pub struct VisionEncoder { - engine: Engine, - num_patch: usize, - patch_size: usize, - processor: Processor, - ts: Ts, -} - -impl VisionEncoder { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; - let (num_patch, patch_size, ts) = ( - engine.batch().opt(), - engine.try_height().unwrap_or(&378.into()).opt(), - engine.ts.clone(), - ); - let processor = options - .to_processor()? 
- .with_image_width(patch_size as _) - .with_image_height(patch_size as _); - - Ok(Self { - engine, - patch_size, - num_patch, - processor, - ts, - }) - } fn create_patches(image: &Image, image_patch_size: usize) -> (Vec, (u32, u32)) { let mut patches = vec![image.clone()]; @@ -515,10 +453,6 @@ impl VisionEncoder { (patches, selected_template) } - pub fn inference(&mut self, xs: Xs) -> Result { - self.engine.run(xs) - } - pub fn encode(&mut self, x: &Image) -> Result { let (patches, selected_template) = Self::create_patches(x, self.patch_size); let patches = self.processor.process_images(&patches)?; @@ -526,7 +460,7 @@ impl VisionEncoder { (selected_template.0 as usize), (selected_template.1 as usize), ); - let patch_emb = self.inference(patches.clone().into())?[0].clone(); + let patch_emb = self.vision_encoder.run(patches.clone().into())?[0].clone(); let patch_emb = patch_emb.clone().0.into_dimensionality::()?; let patch_emb = Self::process_patch_emb(patch_emb, template)?; let patch_emb = X::from(patch_emb.into_dyn()); // TODO .insert_axis(x), @@ -608,30 +542,6 @@ impl VisionEncoder { } } -#[derive(Debug, Builder)] -pub struct VisionProjection { - engine: Engine, - seq_len: usize, - ts: Ts, -} - -impl VisionProjection { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; - let (seq_len, ts) = (engine.inputs_minoptmax[0][1].opt(), engine.ts.clone()); - - Ok(Self { - engine, - seq_len, - ts, - }) - } - - pub fn inference(&mut self, xs: Xs) -> Result { - self.engine.run(xs) - } -} - #[derive(Builder, Debug)] struct KVCache(pub Array); diff --git a/src/models/owl/config.rs b/src/models/owl/config.rs index cc17da9..2520c1b 100644 --- a/src/models/owl/config.rs +++ b/src/models/owl/config.rs @@ -1,11 +1,10 @@ /// Model configuration for `OWLv2` -impl crate::Options { +impl crate::ModelConfig { pub fn owlv2() -> Self { Self::default() - .with_model_name("owlv2") - .with_model_kind(crate::Kind::VisionLanguage) + .with_name("owlv2") // 1st & 
3rd: text - .with_model_ixx(0, 0, (1, 1, 1).into()) // TODO + .with_model_ixx(0, 0, (1, 1, 1).into()) .with_model_ixx(0, 1, 1.into()) .with_model_ixx(2, 0, (1, 1, 1).into()) .with_model_ixx(2, 1, 1.into()) @@ -21,6 +20,7 @@ impl crate::Options { .with_normalize(true) .with_class_confs(&[0.1]) .with_model_num_dry_run(0) + .with_tokenizer_file("owlv2/tokenizer.json") } pub fn owlv2_base() -> Self { diff --git a/src/models/owl/impl.rs b/src/models/owl/impl.rs index ca0ef67..2295c77 100644 --- a/src/models/owl/impl.rs +++ b/src/models/owl/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, X, Y}; +use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct OWLv2 { @@ -22,8 +22,8 @@ pub struct OWLv2 { } impl OWLv2 { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), engine.try_height().unwrap_or(&960.into()).opt(), @@ -31,11 +31,7 @@ impl OWLv2 { engine.ts.clone(), ); let spec = engine.spec().to_owned(); - let processor = options - .to_processor()? - .with_image_width(width as _) - .with_image_height(height as _); - let names: Vec = options + let names: Vec = config .class_names() .expect("No class names specified.") .iter() @@ -44,7 +40,10 @@ impl OWLv2 { let names_with_prompt: Vec = names.iter().map(|x| format!("a photo of {}", x)).collect(); let n = names.len(); - let confs = DynConf::new(options.class_confs(), n); + let confs = DynConf::new(config.class_confs(), n); + let processor = Processor::try_from_config(&config.processor)? 
+ .with_image_width(width as _) + .with_image_height(height as _); let input_ids: Vec = processor .encode_texts_ids( &names_with_prompt diff --git a/src/models/picodet/config.rs b/src/models/picodet/config.rs index b3988a5..3ca083c 100644 --- a/src/models/picodet/config.rs +++ b/src/models/picodet/config.rs @@ -4,11 +4,11 @@ use crate::{ }; /// Model configuration for `PicoDet` -impl crate::Options { +impl crate::ModelConfig { pub fn picodet() -> Self { Self::default() - .with_model_name("picodet") - .with_batch_size(1) // TODO: ONNX model's batch size seems always = 1 + .with_name("picodet") + .with_batch_size_all(1) // TODO: ONNX model's batch size seems always = 1 .with_model_ixx(0, 2, 640.into()) .with_model_ixx(0, 3, 640.into()) .with_model_ixx(1, 0, (1, 1, 8).into()) diff --git a/src/models/picodet/impl.rs b/src/models/picodet/impl.rs index 6bda13a..0a731e3 100644 --- a/src/models/picodet/impl.rs +++ b/src/models/picodet/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::Axis; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, X, Y}; +use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct PicoDet { @@ -19,8 +19,8 @@ pub struct PicoDet { } impl PicoDet { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), engine.try_height().unwrap_or(&640.into()).opt(), @@ -28,15 +28,14 @@ impl PicoDet { engine.ts.clone(), ); let spec = engine.spec().to_owned(); - let processor = options - .to_processor()? 
- .with_image_width(width as _) - .with_image_height(height as _); - let names = options + let names = config .class_names() .expect("No class names are specified.") .to_vec(); - let confs = DynConf::new(options.class_confs(), names.len()); + let confs = DynConf::new(config.class_confs(), names.len()); + let processor = Processor::try_from_config(&config.processor)? + .with_image_width(width as _) + .with_image_height(height as _); Ok(Self { engine, diff --git a/src/models/pipeline/basemodel.rs b/src/models/pipeline/basemodel.rs index 062c84f..3a791f0 100644 --- a/src/models/pipeline/basemodel.rs +++ b/src/models/pipeline/basemodel.rs @@ -2,8 +2,7 @@ use aksr::Builder; use anyhow::Result; use crate::{ - elapsed, DType, Device, Engine, Image, Kind, Options, Processor, Scale, Task, Ts, Version, Xs, - X, + elapsed, DType, Device, Engine, Image, ModelConfig, Processor, Scale, Task, Ts, Version, Xs, X, }; #[derive(Debug, Builder)] @@ -20,7 +19,6 @@ pub struct BaseModelVisual { dtype: DType, task: Option, scale: Option, - kind: Option, version: Option, } @@ -29,8 +27,8 @@ impl BaseModelVisual { self.ts.summary(); } - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let err_msg = "You need to specify the image height and image width for visual model."; let (batch, height, width, ts, spec) = ( engine.batch().opt(), @@ -39,18 +37,16 @@ impl BaseModelVisual { engine.ts.clone(), engine.spec().to_owned(), ); - let processor = options - .to_processor()? + let processor = Processor::try_from_config(&config.processor)? 
.with_image_width(width as _) .with_image_height(height as _); - let device = options.model_device; - let task = options.model_task; - let scale = options.model_scale; - let dtype = options.model_dtype; - let kind = options.model_kind; - let name = options.model_name; - let version = options.model_version; + let device = config.model.device; + let task = config.task; + let scale = config.scale; + let dtype = config.model.dtype; + let name = config.name; + let version = config.version; Ok(Self { engine, @@ -63,7 +59,6 @@ impl BaseModelVisual { dtype, task, scale, - kind, device, version, name, @@ -101,7 +96,6 @@ pub struct BaseModelTextual { dtype: DType, task: Option, scale: Option, - kind: Option, version: Option, } @@ -110,21 +104,20 @@ impl BaseModelTextual { self.ts.summary(); } - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let (batch, ts, spec) = ( engine.batch().opt(), engine.ts.clone(), engine.spec().to_owned(), ); - let processor = options.to_processor()?; - let device = options.model_device; - let task = options.model_task; - let scale = options.model_scale; - let dtype = options.model_dtype; - let kind = options.model_kind; - let name = options.model_name; - let version = options.model_version; + let processor = Processor::try_from_config(&config.processor)?; + let device = config.model.device; + let dtype = config.model.dtype; + let task = config.task; + let scale = config.scale; + let name = config.name; + let version = config.version; Ok(Self { engine, @@ -135,7 +128,6 @@ impl BaseModelTextual { dtype, task, scale, - kind, device, version, name, diff --git a/src/models/pipeline/image_classifier.rs b/src/models/pipeline/image_classifier.rs index 0dafd43..e7a9b12 100644 --- a/src/models/pipeline/image_classifier.rs +++ b/src/models/pipeline/image_classifier.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::Axis; use 
rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Image, Options, Prob, Processor, Ts, Xs, Y}; +use crate::{elapsed, DynConf, Engine, Image, ModelConfig, Prob, Processor, Ts, Xs, Y}; #[derive(Debug, Builder)] pub struct ImageClassifier { @@ -20,11 +20,12 @@ pub struct ImageClassifier { spec: String, } -impl TryFrom for ImageClassifier { +impl TryFrom for ImageClassifier { type Error = anyhow::Error; - fn try_from(options: Options) -> Result { - let engine = options.to_engine()?; + fn try_from(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; + let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( engine.batch().opt(), @@ -32,11 +33,8 @@ impl TryFrom for ImageClassifier { engine.try_width().unwrap_or(&224.into()).opt(), engine.ts().clone(), ); - let processor = options - .to_processor()? - .with_image_width(width as _) - .with_image_height(height as _); - let (nc, names) = match (options.nc(), options.class_names()) { + + let (nc, names) = match (config.nc(), config.class_names()) { (Some(nc), Some(names)) => { if nc != names.len() { anyhow::bail!( @@ -56,8 +54,11 @@ impl TryFrom for ImageClassifier { anyhow::bail!("Neither class names nor class numbers were specified."); } }; - let confs = DynConf::new(options.class_confs(), nc); - let apply_softmax = options.apply_softmax.unwrap_or_default(); + let confs = DynConf::new(config.class_confs(), nc); + let apply_softmax = config.apply_softmax.unwrap_or_default(); + let processor = Processor::try_from_config(&config.processor)? 
+ .with_image_width(width as _) + .with_image_height(height as _); Ok(Self { engine, diff --git a/src/models/rfdetr/config.rs b/src/models/rfdetr/config.rs index 53a861a..85f1220 100644 --- a/src/models/rfdetr/config.rs +++ b/src/models/rfdetr/config.rs @@ -1,18 +1,17 @@ use crate::NAMES_COCO_91; /// Model configuration for `RT-DETR` -impl crate::Options { +impl crate::ModelConfig { pub fn rfdetr() -> Self { Self::default() - .with_model_name("rfdetr") - .with_batch_size(1) + .with_name("rfdetr") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 560.into()) .with_model_ixx(0, 3, 560.into()) .with_resize_mode(crate::ResizeMode::FitAdaptive) - .with_normalize(true) .with_image_mean(&[0.485, 0.456, 0.406]) .with_image_std(&[0.229, 0.224, 0.225]) - .with_class_confs(&[0.25]) .with_class_names(&NAMES_COCO_91) } diff --git a/src/models/rfdetr/impl.rs b/src/models/rfdetr/impl.rs index dcac505..6d15e52 100644 --- a/src/models/rfdetr/impl.rs +++ b/src/models/rfdetr/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, Y}; +use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, Y}; #[derive(Debug, Builder)] pub struct RFDETR { @@ -19,8 +19,8 @@ pub struct RFDETR { } impl RFDETR { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), engine.try_height().unwrap_or(&560.into()).opt(), @@ -28,16 +28,16 @@ impl RFDETR { engine.ts.clone(), ); let spec = engine.spec().to_owned(); - let processor = options - .to_processor()? 
- .with_image_width(width as _) - .with_image_height(height as _); - let names = options + let names: Vec = config .class_names() .expect("No class names specified.") - .to_vec(); - let confs = DynConf::new(options.class_confs(), names.len()); - + .iter() + .map(|x| x.to_string()) + .collect(); + let confs = DynConf::new(config.class_confs(), names.len()); + let processor = Processor::try_from_config(&config.processor)? + .with_image_width(width as _) + .with_image_height(height as _); Ok(Self { engine, height, diff --git a/src/models/rmbg/config.rs b/src/models/rmbg/config.rs index 65ae2bb..2bb963d 100644 --- a/src/models/rmbg/config.rs +++ b/src/models/rmbg/config.rs @@ -1,9 +1,10 @@ /// Model configuration for `RMBG` -impl crate::Options { +impl crate::ModelConfig { pub fn rmbg() -> Self { Self::default() - .with_model_name("rmbg") + .with_name("rmbg") .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 1024.into()) .with_model_ixx(0, 3, 1024.into()) } diff --git a/src/models/rmbg/impl.rs b/src/models/rmbg/impl.rs index f886fa6..ea4d9c6 100644 --- a/src/models/rmbg/impl.rs +++ b/src/models/rmbg/impl.rs @@ -1,7 +1,7 @@ use aksr::Builder; use anyhow::Result; -use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y}; +use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct RMBG { @@ -15,8 +15,8 @@ pub struct RMBG { } impl RMBG { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( engine.batch().opt(), @@ -24,8 +24,7 @@ impl RMBG { engine.try_width().unwrap_or(&1024.into()).opt(), engine.ts().clone(), ); - let processor = options - .to_processor()? + let processor = Processor::try_from_config(&config.processor)? 
.with_image_width(width as _) .with_image_height(height as _); @@ -63,7 +62,6 @@ impl RMBG { fn postprocess(&mut self, xs: Xs) -> Result> { let mut ys: Vec = Vec::new(); for (idx, luma) in xs[0].axis_iter(ndarray::Axis(0)).enumerate() { - // image size let (h1, w1) = ( self.processor.images_transform_info[idx].height_src, self.processor.images_transform_info[idx].width_src, diff --git a/src/models/rtdetr/config.rs b/src/models/rtdetr/config.rs index 2b2ebe9..56f8da4 100644 --- a/src/models/rtdetr/config.rs +++ b/src/models/rtdetr/config.rs @@ -1,15 +1,15 @@ use crate::NAMES_COCO_80; /// Model configuration for `RT-DETR` -impl crate::Options { +impl crate::ModelConfig { pub fn rtdetr() -> Self { Self::default() - .with_model_name("rtdetr") - .with_batch_size(1) + .with_name("rtdetr") + .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 640.into()) .with_model_ixx(0, 3, 640.into()) .with_resize_mode(crate::ResizeMode::FitAdaptive) - .with_normalize(true) .with_class_confs(&[0.5]) .with_class_names(&NAMES_COCO_80) } diff --git a/src/models/rtdetr/impl.rs b/src/models/rtdetr/impl.rs index ebef66b..53db03c 100644 --- a/src/models/rtdetr/impl.rs +++ b/src/models/rtdetr/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, X, Y}; +use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct RTDETR { @@ -19,8 +19,8 @@ pub struct RTDETR { } impl RTDETR { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), engine.try_height().unwrap_or(&640.into()).opt(), @@ -28,15 +28,16 @@ impl RTDETR { engine.ts.clone(), ); let spec = engine.spec().to_owned(); - let processor = options - 
.to_processor()? - .with_image_width(width as _) - .with_image_height(height as _); - let names = options + let names: Vec = config .class_names() .expect("No class names specified.") - .to_vec(); - let confs = DynConf::new(options.class_confs(), names.len()); + .iter() + .map(|x| x.to_string()) + .collect(); + let confs = DynConf::new(config.class_confs(), names.len()); + let processor = Processor::try_from_config(&config.processor)? + .with_image_width(width as _) + .with_image_height(height as _); Ok(Self { engine, @@ -87,7 +88,6 @@ impl RTDETR { .enumerate() .filter_map(|(idx, ((labels, boxes), scores))| { let ratio = self.processor.images_transform_info[idx].height_scale; - let mut y_bboxes = Vec::new(); for (i, &score) in scores.iter().enumerate() { let class_id = labels[i] as usize; @@ -102,7 +102,6 @@ impl RTDETR { xyxy[2] / ratio, xyxy[3] / ratio, ); - y_bboxes.push( Hbb::default() .with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2) diff --git a/src/models/rtmo/config.rs b/src/models/rtmo/config.rs index d223269..43bcb24 100644 --- a/src/models/rtmo/config.rs +++ b/src/models/rtmo/config.rs @@ -1,9 +1,10 @@ /// Model configuration for `RTMO` -impl crate::Options { +impl crate::ModelConfig { pub fn rtmo() -> Self { Self::default() - .with_model_name("rtmo") + .with_name("rtmo") .with_model_ixx(0, 0, 1.into()) + .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 640.into()) .with_model_ixx(0, 3, 640.into()) .with_resize_mode(crate::ResizeMode::FitAdaptive) diff --git a/src/models/rtmo/impl.rs b/src/models/rtmo/impl.rs index 1edb944..0b42d30 100644 --- a/src/models/rtmo/impl.rs +++ b/src/models/rtmo/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::Axis; -use crate::{elapsed, DynConf, Engine, Hbb, Image, Keypoint, Options, Processor, Ts, Xs, Y}; +use crate::{elapsed, DynConf, Engine, Hbb, Image, Keypoint, ModelConfig, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct RTMO { @@ -18,8 +18,8 @@ pub struct RTMO { } impl RTMO 
{ - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( engine.batch().opt(), @@ -27,15 +27,14 @@ impl RTMO { engine.try_width().unwrap_or(&512.into()).opt(), engine.ts().clone(), ); - let processor = options - .to_processor()? + + let nk = config.nk().unwrap_or(17); + let confs = DynConf::new(config.class_confs(), 1); + let kconfs = DynConf::new(config.keypoint_confs(), nk); + let processor = Processor::try_from_config(&config.processor)? .with_image_width(width as _) .with_image_height(height as _); - let nk = options.nk().unwrap_or(17); - let confs = DynConf::new(options.class_confs(), 1); - let kconfs = DynConf::new(options.keypoint_confs(), nk); - Ok(Self { engine, height, diff --git a/src/models/sam/config.rs b/src/models/sam/config.rs index 0e9ce58..f46e86c 100644 --- a/src/models/sam/config.rs +++ b/src/models/sam/config.rs @@ -1,100 +1,73 @@ -use crate::{models::SamKind, Options}; +use crate::{models::SamKind, ModelConfig}; /// Model configuration for `Segment Anything Model` -impl Options { +impl ModelConfig { pub fn sam() -> Self { Self::default() - .with_model_name("sam") - .with_model_ixx(0, 0, 1.into()) - } - - pub fn sam_encoder() -> Self { - Self::sam() - .with_model_ixx(0, 2, 1024.into()) - .with_model_ixx(0, 3, 1024.into()) + .with_name("sam") + .with_encoder_ixx(0, 0, 1.into()) + .with_encoder_ixx(0, 1, 3.into()) + .with_encoder_ixx(0, 2, 1024.into()) + .with_encoder_ixx(0, 3, 1024.into()) .with_resize_mode(crate::ResizeMode::FitAdaptive) .with_resize_filter("Bilinear") .with_image_mean(&[123.5, 116.5, 103.5]) .with_image_std(&[58.5, 57.0, 57.5]) .with_normalize(false) .with_sam_kind(SamKind::Sam) - .with_low_res_mask(false) + .with_sam_low_res_mask(false) .with_find_contours(true) } - pub fn sam_decoder() -> Self { + pub fn sam_v1_base() -> Self { 
Self::sam() + .with_encoder_file("sam-vit-b-encoder.onnx") + .with_decoder_file("sam-vit-b-decoder.onnx") } - pub fn sam_v1_base_encoder() -> Self { - Self::sam_encoder().with_model_file("sam-vit-b-encoder.onnx") + // pub fn sam_v1_base_singlemask_decoder() -> Self { + // Self::sam().with_decoder_file("sam-vit-b-decoder-singlemask.onnx") + // } + + pub fn sam2_tiny() -> Self { + Self::sam() + .with_encoder_file("sam2-hiera-tiny-encoder.onnx") + .with_sam_kind(SamKind::Sam2) + .with_decoder_file("sam2-hiera-tiny-decoder.onnx") } - pub fn sam_v1_base_decoder() -> Self { - Self::sam_decoder().with_model_file("sam-vit-b-decoder.onnx") - } - - pub fn sam_v1_base_singlemask_decoder() -> Self { - Self::sam_decoder().with_model_file("sam-vit-b-decoder-singlemask.onnx") - } - - pub fn sam2_tiny_encoder() -> Self { - Self::sam_encoder() - .with_model_file("sam2-hiera-tiny-encoder.onnx") + pub fn sam2_small() -> Self { + Self::sam() + .with_encoder_file("sam2-hiera-small-encoder.onnx") + .with_decoder_file("sam2-hiera-small-decoder.onnx") .with_sam_kind(SamKind::Sam2) } - pub fn sam2_tiny_decoder() -> Self { - Self::sam_decoder().with_model_file("sam2-hiera-tiny-decoder.onnx") - } - - pub fn sam2_small_encoder() -> Self { - Self::sam_encoder() - .with_model_file("sam2-hiera-small-encoder.onnx") + pub fn sam2_base_plus() -> Self { + Self::sam() + .with_encoder_file("sam2-hiera-base-plus-encoder.onnx") + .with_decoder_file("sam2-hiera-base-plus-decoder.onnx") .with_sam_kind(SamKind::Sam2) } - pub fn sam2_small_decoder() -> Self { - Self::sam_decoder().with_model_file("sam2-hiera-small-decoder.onnx") - } - - pub fn sam2_base_plus_encoder() -> Self { - Self::sam_encoder() - .with_model_file("sam2-hiera-base-plus-encoder.onnx") - .with_sam_kind(SamKind::Sam2) - } - - pub fn sam2_base_plus_decoder() -> Self { - Self::sam_decoder().with_model_file("sam2-hiera-base-plus-decoder.onnx") - } - - pub fn mobile_sam_tiny_encoder() -> Self { - Self::sam_encoder() - 
.with_model_file("mobile-sam-vit-t-encoder.onnx") + pub fn mobile_sam_tiny() -> Self { + Self::sam() + .with_encoder_file("mobile-sam-vit-t-encoder.onnx") .with_sam_kind(SamKind::MobileSam) + .with_decoder_file("mobile-sam-vit-t-decoder.onnx") } - pub fn mobile_sam_tiny_decoder() -> Self { - Self::sam_decoder().with_model_file("mobile-sam-vit-t-decoder.onnx") - } - - pub fn sam_hq_tiny_encoder() -> Self { - Self::sam_encoder() - .with_model_file("sam-hq-vit-t-encoder.onnx") + pub fn sam_hq_tiny() -> Self { + Self::sam() + .with_encoder_file("sam-hq-vit-t-encoder.onnx") .with_sam_kind(SamKind::SamHq) + .with_decoder_file("sam-hq-vit-t-decoder.onnx") } - pub fn sam_hq_tiny_decoder() -> Self { - Self::sam_decoder().with_model_file("sam-hq-vit-t-decoder.onnx") - } - - pub fn edge_sam_3x_encoder() -> Self { - Self::sam_encoder() - .with_model_file("edge-sam-3x-encoder.onnx") + pub fn edge_sam_3x() -> Self { + Self::sam() + .with_encoder_file("edge-sam-3x-encoder.onnx") + .with_decoder_file("edge-sam-3x-decoder.onnx") .with_sam_kind(SamKind::EdgeSam) } - - pub fn edge_sam_3x_decoder() -> Self { - Self::sam_decoder().with_model_file("edge-sam-3x-decoder.onnx") - } } diff --git a/src/models/sam/impl.rs b/src/models/sam/impl.rs index 0e02ca5..fea747d 100644 --- a/src/models/sam/impl.rs +++ b/src/models/sam/impl.rs @@ -4,8 +4,8 @@ use ndarray::{s, Axis}; use rand::prelude::*; use crate::{ - elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, SamPrompt, Ts, Xs, X, - Y, + elapsed, DynConf, Engine, Image, Mask, ModelConfig, Ops, Polygon, Processor, SamPrompt, Ts, Xs, + X, Y, }; #[derive(Debug, Clone)] @@ -49,9 +49,10 @@ pub struct SAM { } impl SAM { - pub fn new(options_encoder: Options, options_decoder: Options) -> Result { - let encoder = options_encoder.to_engine()?; - let decoder = options_decoder.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let encoder = Engine::try_from_config(&config.encoder)?; + let decoder = 
Engine::try_from_config(&config.decoder)?; + let (batch, height, width) = ( encoder.batch().opt(), encoder.try_height().unwrap_or(&1024.into()).opt(), @@ -60,24 +61,23 @@ impl SAM { let ts = Ts::merge(&[encoder.ts(), decoder.ts()]); let spec = encoder.spec().to_owned(); - let processor = options_encoder - .to_processor()? - .with_image_width(width as _) - .with_image_height(height as _); - - let conf = DynConf::new(options_encoder.class_confs(), 1); - let find_contours = options_encoder.find_contours; - let kind = match options_encoder.sam_kind { + let conf = DynConf::new(config.class_confs(), 1); + let find_contours = config.find_contours; + let kind = match config.sam_kind { Some(x) => x, None => anyhow::bail!("Error: no clear `SamKind` specified."), }; let use_low_res_mask = match kind { SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => { - options_encoder.low_res_mask.unwrap_or(false) + config.sam_low_res_mask.unwrap_or(false) } SamKind::EdgeSam | SamKind::Sam2 => true, }; + let processor = Processor::try_from_config(&config.processor)? 
+ .with_image_width(width as _) + .with_image_height(height as _); + Ok(Self { encoder, decoder, diff --git a/src/models/sam2/config.rs b/src/models/sam2/config.rs index db9df28..f58f7a7 100644 --- a/src/models/sam2/config.rs +++ b/src/models/sam2/config.rs @@ -1,50 +1,28 @@ -use crate::Options; +use crate::ModelConfig; /// Model configuration for `SAM2.1` -impl Options { - pub fn sam2_encoder() -> Self { +impl ModelConfig { + pub fn sam2_1_tiny() -> Self { Self::sam() - .with_model_ixx(0, 2, 1024.into()) - .with_model_ixx(0, 3, 1024.into()) - .with_resize_mode(crate::ResizeMode::FitAdaptive) - .with_resize_filter("Bilinear") - .with_image_mean(&[0.485, 0.456, 0.406]) - .with_image_std(&[0.229, 0.224, 0.225]) + .with_encoder_file("sam2.1-hiera-tiny-encoder.onnx") + .with_decoder_file("sam2.1-hiera-tiny-decoder.onnx") } - pub fn sam2_decoder() -> Self { + pub fn sam2_1_small() -> Self { Self::sam() + .with_encoder_file("sam2.1-hiera-small-encoder.onnx") + .with_decoder_file("sam2.1-hiera-small-decoder.onnx") } - pub fn sam2_1_tiny_encoder() -> Self { - Self::sam2_encoder().with_model_file("sam2.1-hiera-tiny-encoder.onnx") + pub fn sam2_1_base_plus() -> Self { + Self::sam() + .with_encoder_file("sam2.1-hiera-base-plus-encoder.onnx") + .with_decoder_file("sam2.1-hiera-base-plus-decoder.onnx") } - pub fn sam2_1_tiny_decoder() -> Self { - Self::sam2_decoder().with_model_file("sam2.1-hiera-tiny-decoder.onnx") - } - - pub fn sam2_1_small_encoder() -> Self { - Self::sam2_encoder().with_model_file("sam2.1-hiera-small-encoder.onnx") - } - - pub fn sam2_1_small_decoder() -> Self { - Self::sam2_decoder().with_model_file("sam2.1-hiera-small-decoder.onnx") - } - - pub fn sam2_1_base_plus_encoder() -> Self { - Self::sam2_encoder().with_model_file("sam2.1-hiera-base-plus-encoder.onnx") - } - - pub fn sam2_1_base_plus_decoder() -> Self { - Self::sam2_decoder().with_model_file("sam2.1-hiera-base-plus-decoder.onnx") - } - - pub fn sam2_1_large_encoder() -> Self { - 
Self::sam2_encoder().with_model_file("sam2.1-hiera-large-encoder.onnx") - } - - pub fn sam2_1_large_decoder() -> Self { - Self::sam2_decoder().with_model_file("sam2.1-hiera-large-decoder.onnx") + pub fn sam2_1_large() -> Self { + Self::sam() + .with_encoder_file("sam2.1-hiera-large-encoder.onnx") + .with_decoder_file("sam2.1-hiera-large-decoder.onnx") } } diff --git a/src/models/sam2/impl.rs b/src/models/sam2/impl.rs index cb32dce..9f576dc 100644 --- a/src/models/sam2/impl.rs +++ b/src/models/sam2/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use crate::{ - elapsed, DynConf, Engine, Image, Mask, Ops, Options, Processor, SamPrompt, Ts, Xs, X, Y, + elapsed, DynConf, Engine, Image, Mask, ModelConfig, Ops, Processor, SamPrompt, Ts, Xs, X, Y, }; #[derive(Builder, Debug)] @@ -20,9 +20,9 @@ pub struct SAM2 { } impl SAM2 { - pub fn new(options_encoder: Options, options_decoder: Options) -> Result { - let encoder = options_encoder.to_engine()?; - let decoder = options_decoder.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let encoder = Engine::try_from_config(&config.encoder)?; + let decoder = Engine::try_from_config(&config.decoder)?; let (batch, height, width) = ( encoder.batch().opt(), encoder.try_height().unwrap_or(&1024.into()).opt(), @@ -30,11 +30,11 @@ impl SAM2 { ); let ts = Ts::merge(&[encoder.ts(), decoder.ts()]); let spec = encoder.spec().to_owned(); - let processor = options_encoder - .to_processor()? + + let conf = DynConf::new(config.class_confs(), 1); + let processor = Processor::try_from_config(&config.processor)? 
.with_image_width(width as _) .with_image_height(height as _); - let conf = DynConf::new(options_encoder.class_confs(), 1); Ok(Self { encoder, diff --git a/src/models/sapiens/config.rs b/src/models/sapiens/config.rs index 053cf07..a51fbc7 100644 --- a/src/models/sapiens/config.rs +++ b/src/models/sapiens/config.rs @@ -1,15 +1,14 @@ use crate::NAMES_BODY_PARTS_28; /// Model configuration for `Sapiens` -impl crate::Options { +impl crate::ModelConfig { pub fn sapiens() -> Self { Self::default() - .with_model_name("sapiens") + .with_name("sapiens") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 2, 1024.into()) .with_model_ixx(0, 3, 768.into()) .with_resize_mode(crate::ResizeMode::FitExact) - .with_resize_filter("Bilinear") .with_image_mean(&[123.5, 116.5, 103.5]) .with_image_std(&[58.5, 57.0, 57.5]) .with_normalize(false) @@ -17,31 +16,11 @@ impl crate::Options { pub fn sapiens_body_part_segmentation() -> Self { Self::sapiens() - .with_model_task(crate::Task::InstanceSegmentation) + .with_task(crate::Task::InstanceSegmentation) .with_class_names(&NAMES_BODY_PARTS_28) } pub fn sapiens_seg_0_3b() -> Self { Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b.onnx") } - - // pub fn sapiens_seg_0_3b_uint8() -> Self { - // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-uint8.onnx") - // } - - // pub fn sapiens_seg_0_3b_fp16() -> Self { - // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-fp16.onnx") - // } - - // pub fn sapiens_seg_0_3b_bnb4() -> Self { - // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-bnb4.onnx") - // } - - // pub fn sapiens_seg_0_3b_q4f16() -> Self { - // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-q4f16.onnx") - // } - - // pub fn sapiens_seg_0_6b_fp16() -> Self { - // Self::sapiens_body_part_segmentation().with_model_file("seg-0.6b-fp16.onnx") - // } } diff --git a/src/models/sapiens/impl.rs b/src/models/sapiens/impl.rs index 65ea6fd..a192bd5 100644 --- 
a/src/models/sapiens/impl.rs +++ b/src/models/sapiens/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::{s, Array2, Axis}; -use crate::{elapsed, Engine, Image, Mask, Ops, Options, Polygon, Processor, Task, Ts, Xs, Y}; +use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Polygon, Processor, Task, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct Sapiens { @@ -18,8 +18,8 @@ pub struct Sapiens { } impl Sapiens { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( engine.batch().opt(), @@ -27,12 +27,12 @@ impl Sapiens { engine.try_width().unwrap_or(&768.into()).opt(), engine.ts().clone(), ); - let processor = options - .to_processor()? + + let task = config.task.expect("No sapiens task specified."); + let names_body = config.class_names; + let processor = Processor::try_from_config(&config.processor)? 
.with_image_width(width as _) .with_image_height(height as _); - let task = options.model_task.expect("No sapiens task specified."); - let names_body = options.class_names; Ok(Self { engine, diff --git a/src/models/slanet/config.rs b/src/models/slanet/config.rs index f29b311..d045fdc 100644 --- a/src/models/slanet/config.rs +++ b/src/models/slanet/config.rs @@ -1,14 +1,14 @@ /// Model configuration for `SLANet` -impl crate::Options { +impl crate::ModelConfig { pub fn slanet() -> Self { Self::default() - .with_model_name("slanet") + .with_name("slanet") .with_model_ixx(0, 0, (1, 1, 8).into()) + .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, (320, 488, 488).into()) .with_model_ixx(0, 3, (320, 488, 488).into()) .with_image_mean(&[0.485, 0.456, 0.406]) .with_image_std(&[0.229, 0.224, 0.225]) - .with_normalize(true) .with_resize_mode(crate::ResizeMode::FitAdaptive) .with_padding_value(0) .with_unsigned(true) @@ -17,6 +17,6 @@ impl crate::Options { pub fn slanet_lcnet_v2_mobile_ch() -> Self { Self::slanet() .with_model_file("v2-mobile-ch.onnx") - .with_vocab_txt("vocab-sla-v2.txt") + .with_vocab_txt("slanet/vocab-sla-v2.txt") } } diff --git a/src/models/slanet/impl.rs b/src/models/slanet/impl.rs index fc2524e..54da596 100644 --- a/src/models/slanet/impl.rs +++ b/src/models/slanet/impl.rs @@ -2,7 +2,7 @@ use aksr::Builder; use anyhow::Result; use ndarray::{s, Axis}; -use crate::{elapsed, models::BaseModelVisual, Image, Keypoint, Options, Ts, Xs, Y}; +use crate::{elapsed, models::BaseModelVisual, Image, Keypoint, ModelConfig, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct SLANet { @@ -19,8 +19,8 @@ impl SLANet { self.ts.summary(); } - pub fn new(options: Options) -> Result { - let base = BaseModelVisual::new(options)?; + pub fn new(config: ModelConfig) -> Result { + let base = BaseModelVisual::new(config)?; let spec = base.engine().spec().to_owned(); let sos = 0; let eos = base.processor().vocab().len() - 1; diff --git a/src/models/smolvlm/config.rs 
b/src/models/smolvlm/config.rs index 1aa4c9c..339f115 100644 --- a/src/models/smolvlm/config.rs +++ b/src/models/smolvlm/config.rs @@ -1,58 +1,28 @@ /// Model configuration for `SmolVLM` -impl crate::Options { +impl crate::ModelConfig { pub fn smolvlm() -> Self { Self::default() - .with_batch_size(1) - .with_model_name("smolvlm") - .with_model_num_dry_run(3) - } - - pub fn smolvlm_vision() -> Self { - Self::smolvlm() - .with_model_kind(crate::Kind::Vision) + .with_name("smolvlm") + .with_batch_size_all(1) .with_image_mean(&[0.5, 0.5, 0.5]) .with_image_std(&[0.5, 0.5, 0.5]) .with_resize_filter("lanczos3") - .with_normalize(true) + .with_tokenizer_file("smolvlm/tokenizer.json") } - pub fn smolvlm_text() -> Self { - Self::smolvlm().with_model_kind(crate::Kind::Language) + pub fn smolvlm_256m() -> Self { + Self::smolvlm() + .with_scale(crate::Scale::Million(256.)) + .with_visual_file("256m-vision-encoder.onnx") + .with_textual_file("256m-embed-tokens.onnx") + .with_textual_decoder_file("256m-decoder-model-merged.onnx") } - pub fn smolvlm_vision_256m() -> Self { - Self::smolvlm_vision() - .with_model_scale(crate::Scale::Million(256.)) - .with_model_file("256m-vision-encoder.onnx") - } - - pub fn smolvlm_text_embed_256m() -> Self { - Self::smolvlm_text() - .with_model_scale(crate::Scale::Million(256.)) - .with_model_file("256m-embed-tokens.onnx") - } - - pub fn smolvlm_decoder_256m() -> Self { - Self::smolvlm_text() - .with_model_scale(crate::Scale::Million(256.)) - .with_model_file("256m-decoder-model-merged.onnx") - } - - pub fn smolvlm_vision_500m() -> Self { - Self::smolvlm_vision() - .with_model_scale(crate::Scale::Million(500.)) - .with_model_file("500m-vision-encoder.onnx") - } - - pub fn smolvlm_text_embed_500m() -> Self { - Self::smolvlm_text() - .with_model_scale(crate::Scale::Million(500.)) - .with_model_file("500m-embed-tokens.onnx") - } - - pub fn smolvlm_decoder_500m() -> Self { - Self::smolvlm_text() - .with_model_scale(crate::Scale::Million(500.)) - 
.with_model_file("500m-decoder-model-merged.onnx") + pub fn smolvlm_500m() -> Self { + Self::smolvlm() + .with_scale(crate::Scale::Million(500.)) + .with_visual_file("500m-vision-encoder.onnx") + .with_textual_file("500m-embed-tokens.onnx") + .with_textual_decoder_file("500m-decoder-model-merged.onnx") } } diff --git a/src/models/smolvlm/impl.rs b/src/models/smolvlm/impl.rs index 9cb09e8..48cf3cc 100644 --- a/src/models/smolvlm/impl.rs +++ b/src/models/smolvlm/impl.rs @@ -3,15 +3,13 @@ use anyhow::Result; use image::GenericImageView; use ndarray::s; -use crate::{ - models::BaseModelTextual, Engine, Image, LogitsSampler, Options, Processor, Scale, Ts, Xs, X, Y, -}; +use crate::{Engine, Image, LogitsSampler, ModelConfig, Processor, Scale, Ts, Xs, X, Y}; #[derive(Debug, Builder)] pub struct SmolVLM { - vision: VisionEncoder, - text_embed: BaseModelTextual, - decoder: BaseModelTextual, + vision: Engine, + text_embed: Engine, + decoder: Engine, scale: Scale, image_token: String, global_img_token: String, @@ -25,17 +23,20 @@ pub struct SmolVLM { num_hidden_layers: usize, head_dim: usize, num_key_value_heads: usize, + num_patch: usize, + batch: usize, + width: usize, + height: usize, + processor: Processor, + ts: Ts, } impl SmolVLM { - pub fn new( - options_vision_encoder: Options, - options_text_embed: Options, - options_decode: Options, - ) -> Result { - let vision = VisionEncoder::new(options_vision_encoder)?; - let text_embed = BaseModelTextual::new(options_text_embed)?; - let decoder = BaseModelTextual::new(options_decode)?; + pub fn new(config: ModelConfig) -> Result { + let vision = Engine::try_from_config(&config.visual)?; + let text_embed = Engine::try_from_config(&config.textual)?; + let decoder = Engine::try_from_config(&config.textual_decoder_merged)?; + let fake_image_token = "".to_string(); let image_token = "".to_string(); let global_img_token = "".to_string(); @@ -45,12 +46,23 @@ impl SmolVLM { let image_token_id = 49190; let image_seq_len = 64; let 
max_length = 1024; - let (num_hidden_layers, head_dim, num_key_value_heads) = match decoder.scale() { + let (num_hidden_layers, head_dim, num_key_value_heads) = match &config.scale { Some(Scale::Million(256.)) => (30, 64, 3), Some(Scale::Million(500.)) => (32, 64, 5), _ => unimplemented!(), }; - let scale = (*decoder.scale().unwrap()).clone(); + let scale = config.scale.clone().unwrap(); + + let (batch, num_patch, height, width, ts) = ( + vision.batch().opt(), + vision.inputs_minoptmax()[0][1].opt(), + vision.inputs_minoptmax()[0][3].opt(), + vision.inputs_minoptmax()[0][4].opt(), + vision.ts.clone(), + ); + let processor = Processor::try_from_config(&config.processor)? + .with_image_width(width as _) + .with_image_height(height as _); Ok(Self { vision, @@ -69,6 +81,12 @@ impl SmolVLM { bos_token, eos_token, image_seq_len, + batch, + num_patch, + height, + width, + ts, + processor, }) } @@ -86,13 +104,13 @@ impl SmolVLM { let bs = 1; // TODO // patches and pixel_attention_mask - let (patches, nw_nh) = self.vision.process_one(image)?; + let (patches, nw_nh) = self.process_one(image)?; let dims = patches.dims(); let pixel_attention_mask = X::ones(&[dims[0], dims[1], dims[3], dims[4]]); // input ids let prompt = self.image_prompt_string(nw_nh, text); - let mut input_ids: Vec = self.text_embed.processor().encode_text_ids(&prompt, true)?; + let mut input_ids: Vec = self.processor.encode_text_ids(&prompt, true)?; // position ids let mut position_ids = X::from( @@ -114,12 +132,11 @@ impl SmolVLM { for ii in 0..self.max_length { // inputs embeds let input_ids_x = X::from(input_ids.clone()).insert_axis(0)?; - let mut inputs_embeds = - self.text_embed.inference(input_ids_x.clone().into())?[0].clone(); + let mut inputs_embeds = self.text_embed.run(input_ids_x.clone().into())?[0].clone(); // encode image and merge if ii == 0 { - let image_features = self.vision.inference(Xs::from(vec![ + let image_features = self.vision.run(Xs::from(vec![ patches.clone(), 
pixel_attention_mask.clone(), ]))?[0] @@ -152,7 +169,7 @@ impl SmolVLM { } // decode - let decoder_outputs = self.decoder.inference(xs.into())?; + let decoder_outputs = self.decoder.run(xs.into())?; let logits = &decoder_outputs[0]; past_key_values = (1..decoder_outputs.len()) .step_by(2) @@ -186,10 +203,7 @@ impl SmolVLM { } // decode tokens - let text = self - .text_embed - .processor() - .decode_tokens(&token_ids, true)?; + let text = self.processor.decode_tokens(&token_ids, true)?; Ok(text) } @@ -233,44 +247,6 @@ impl SmolVLM { } } } -} - -#[derive(Debug, Builder)] -pub struct VisionEncoder { - engine: Engine, - num_patch: usize, - batch: usize, - width: usize, - height: usize, - processor: Processor, - ts: Ts, -} - -impl VisionEncoder { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; - let (batch, num_patch, height, width, ts) = ( - engine.batch().opt(), - engine.inputs_minoptmax()[0][1].opt(), - engine.inputs_minoptmax()[0][3].opt(), - engine.inputs_minoptmax()[0][4].opt(), - engine.ts.clone(), - ); - let processor = options - .to_processor()? 
- .with_image_width(width as _) - .with_image_height(height as _); - - Ok(Self { - engine, - num_patch, - batch, - width, - height, - processor, - ts, - }) - } fn create_patches(image: &Image, patch_size: (u32, u32)) -> (Vec, (u32, u32)) { let mut patches = vec![]; @@ -307,10 +283,6 @@ impl VisionEncoder { (patches, (nw, nh)) } - pub fn inference(&mut self, xs: Xs) -> Result { - self.engine.run(xs) - } - pub fn process_one(&mut self, x: &Image) -> Result<(X, (u32, u32))> { let (patches, nw_nh) = Self::create_patches(x, (self.width as _, self.height as _)); let patches = self.processor.process_images(&patches)?.insert_axis(0)?; diff --git a/src/models/svtr/config.rs b/src/models/svtr/config.rs index 8f0c06e..583b7f5 100644 --- a/src/models/svtr/config.rs +++ b/src/models/svtr/config.rs @@ -1,8 +1,8 @@ /// Model configuration for `SVTR` -impl crate::Options { +impl crate::ModelConfig { pub fn svtr() -> Self { Self::default() - .with_model_name("svtr") + .with_name("svtr") .with_model_ixx(0, 0, (1, 1, 8).into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 48.into()) @@ -14,11 +14,11 @@ impl crate::Options { } pub fn svtr_ch() -> Self { - Self::svtr().with_vocab_txt("vocab-v1-ppocr-rec-ch.txt") + Self::svtr().with_vocab_txt("svtr/vocab-v1-ppocr-rec-ch.txt") } pub fn svtr_en() -> Self { - Self::svtr().with_vocab_txt("vocab-v1-ppocr-rec-en.txt") + Self::svtr().with_vocab_txt("svtr/vocab-v1-ppocr-rec-en.txt") } pub fn ppocr_rec_v3_ch() -> Self { diff --git a/src/models/svtr/impl.rs b/src/models/svtr/impl.rs index eff6bda..728c7bb 100644 --- a/src/models/svtr/impl.rs +++ b/src/models/svtr/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::Axis; use rayon::prelude::*; -use crate::{elapsed, DynConf, Engine, Image, Options, Processor, Ts, Xs, Y}; +use crate::{elapsed, DynConf, Engine, Image, ModelConfig, Processor, Ts, Xs, Y}; #[derive(Builder, Debug)] pub struct SVTR { @@ -18,18 +18,17 @@ pub struct SVTR { } impl SVTR { - pub fn new(options: Options) -> 
Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let (batch, height, width, ts) = ( engine.batch().opt(), engine.try_height().unwrap_or(&960.into()).opt(), engine.try_width().unwrap_or(&960.into()).opt(), engine.ts.clone(), ); - let spec = options.model_spec().to_string(); - let confs = DynConf::new(options.class_confs(), 1); - let processor = options - .to_processor()? + let spec = config.model.spec.to_string(); + let confs = DynConf::new(config.class_confs(), 1); + let processor = Processor::try_from_config(&config.processor)? .with_image_width(width as _) .with_image_height(height as _); if processor.vocab().is_empty() { diff --git a/src/models/trocr/config.rs b/src/models/trocr/config.rs index 8343434..291c8e6 100644 --- a/src/models/trocr/config.rs +++ b/src/models/trocr/config.rs @@ -1,92 +1,52 @@ use crate::Scale; /// Model configuration for `TrOCR` -impl crate::Options { +impl crate::ModelConfig { pub fn trocr() -> Self { - Self::default().with_model_name("trocr").with_batch_size(1) - } - - pub fn trocr_visual() -> Self { - Self::trocr() - .with_model_kind(crate::Kind::Vision) - .with_model_ixx(0, 1, 3.into()) - .with_model_ixx(0, 2, 384.into()) - .with_model_ixx(0, 3, 384.into()) + Self::default() + .with_name("trocr") + .with_batch_size_all(1) + .with_visual_ixx(0, 1, 3.into()) + .with_visual_ixx(0, 2, 384.into()) + .with_visual_ixx(0, 3, 384.into()) .with_image_mean(&[0.5, 0.5, 0.5]) .with_image_std(&[0.5, 0.5, 0.5]) - .with_resize_filter("Bilinear") - .with_normalize(true) + .with_resize_filter("lanczos3") + .with_tokenizer_file("trocr/tokenizer.json") + .with_config_file("trocr/config.json") + .with_special_tokens_map_file("trocr/special_tokens_map.json") + .with_tokenizer_config_file("trocr/tokenizer_config.json") } - pub fn trocr_textual() -> Self { - Self::trocr().with_model_kind(crate::Kind::Language) - } - - pub fn trocr_visual_small() -> Self { - 
Self::trocr_visual().with_model_scale(Scale::S) - } - - pub fn trocr_textual_small() -> Self { - Self::trocr_textual() - .with_model_scale(Scale::S) + pub fn trocr_small_printed() -> Self { + Self::trocr() + .with_scale(Scale::S) + .with_visual_file("s-encoder-printed.onnx") + .with_textual_decoder_file("s-decoder-printed.onnx") + .with_textual_decoder_merged_file("s-decoder-merged-printed.onnx") .with_tokenizer_file("trocr/tokenizer-small.json") } - pub fn trocr_visual_base() -> Self { - Self::trocr_visual().with_model_scale(Scale::B) - } - - pub fn trocr_textual_base() -> Self { - Self::trocr_textual() - .with_model_scale(Scale::B) + pub fn trocr_base_handwritten() -> Self { + Self::trocr() + .with_scale(Scale::B) + .with_visual_file("b-encoder-handwritten.onnx") + .with_textual_decoder_file("b-decoder-handwritten.onnx") + .with_textual_decoder_merged_file("b-decoder-merged-handwritten.onnx") .with_tokenizer_file("trocr/tokenizer-base.json") } - pub fn trocr_encoder_small_printed() -> Self { - Self::trocr_visual_small().with_model_file("s-encoder-printed.onnx") + pub fn trocr_small_handwritten() -> Self { + Self::trocr_small_printed() + .with_visual_file("s-encoder-handwritten.onnx") + .with_textual_decoder_file("s-decoder-handwritten.onnx") + .with_textual_decoder_merged_file("s-decoder-merged-handwritten.onnx") } - pub fn trocr_decoder_small_printed() -> Self { - Self::trocr_textual_small().with_model_file("s-decoder-printed.onnx") - } - - pub fn trocr_decoder_merged_small_printed() -> Self { - Self::trocr_textual_small().with_model_file("s-decoder-merged-printed.onnx") - } - - pub fn trocr_encoder_small_handwritten() -> Self { - Self::trocr_visual_small().with_model_file("s-encoder-handwritten.onnx") - } - - pub fn trocr_decoder_small_handwritten() -> Self { - Self::trocr_textual_small().with_model_file("s-decoder-handwritten.onnx") - } - - pub fn trocr_decoder_merged_small_handwritten() -> Self { - 
Self::trocr_textual_small().with_model_file("s-decoder-merged-handwritten.onnx") - } - - pub fn trocr_encoder_base_printed() -> Self { - Self::trocr_visual_base().with_model_file("b-encoder-printed.onnx") - } - - pub fn trocr_decoder_base_printed() -> Self { - Self::trocr_textual_base().with_model_file("b-decoder-printed.onnx") - } - - pub fn trocr_decoder_merged_base_printed() -> Self { - Self::trocr_textual_base().with_model_file("b-decoder-merged-printed.onnx") - } - - pub fn trocr_encoder_base_handwritten() -> Self { - Self::trocr_visual_base().with_model_file("b-encoder-handwritten.onnx") - } - - pub fn trocr_decoder_base_handwritten() -> Self { - Self::trocr_textual_base().with_model_file("b-decoder-handwritten.onnx") - } - - pub fn trocr_decoder_merged_base_handwritten() -> Self { - Self::trocr_textual_base().with_model_file("b-decoder-merged-handwritten.onnx") + pub fn trocr_base_printed() -> Self { + Self::trocr_base_handwritten() + .with_visual_file("b-encoder-printed.onnx") + .with_textual_decoder_file("b-decoder-printed.onnx") + .with_textual_decoder_merged_file("b-decoder-merged-printed.onnx") } } diff --git a/src/models/trocr/impl.rs b/src/models/trocr/impl.rs index a0759a7..e7a981e 100644 --- a/src/models/trocr/impl.rs +++ b/src/models/trocr/impl.rs @@ -3,11 +3,7 @@ use anyhow::Result; use ndarray::{s, Axis}; use rayon::prelude::*; -use crate::{ - elapsed, - models::{BaseModelTextual, BaseModelVisual}, - Image, LogitsSampler, Options, Scale, Ts, Xs, X, Y, -}; +use crate::{elapsed, Engine, Image, LogitsSampler, ModelConfig, Processor, Scale, Ts, Xs, X, Y}; #[derive(Debug, Copy, Clone)] pub enum TrOCRKind { @@ -29,35 +25,31 @@ impl TryFrom<&str> for TrOCRKind { #[derive(Debug, Builder)] pub struct TrOCR { - encoder: BaseModelVisual, - decoder: BaseModelTextual, - decoder_merged: BaseModelTextual, + encoder: Engine, + decoder: Engine, + decoder_merged: Engine, max_length: u32, eos_token_id: u32, decoder_start_token_id: u32, ts: Ts, n_kvs: usize, + 
processor: Processor, + batch: usize, + height: usize, + width: usize, } impl TrOCR { - pub fn summary(&self) { - self.ts.summary(); - } - - pub fn new( - options_encoder: Options, - options_decoder: Options, - options_decoder_merged: Options, - ) -> Result { - let encoder = BaseModelVisual::new(options_encoder)?; - let decoder = BaseModelTextual::new(options_decoder)?; - let decoder_merged = BaseModelTextual::new(options_decoder_merged)?; - let ts = Ts::merge(&[ - encoder.engine().ts(), - decoder.engine().ts(), - decoder_merged.engine().ts(), - ]); - + pub fn new(config: ModelConfig) -> Result { + let encoder = Engine::try_from_config(&config.visual)?; + let decoder = Engine::try_from_config(&config.textual_decoder)?; + let decoder_merged = Engine::try_from_config(&config.textual_decoder_merged)?; + let (batch, height, width) = ( + encoder.batch().opt(), + encoder.try_height().unwrap_or(&384.into()).opt(), + encoder.try_width().unwrap_or(&384.into()).opt(), + ); + let ts = Ts::merge(&[encoder.ts(), decoder.ts(), decoder_merged.ts()]); // "bos_token": "", "eos_token": "", "sep_token": "", // "model_max_length": 1000000000000000019884624838656, // let bos_token = ""; @@ -68,12 +60,16 @@ impl TrOCR { let max_length = 1024; // TODO let eos_token_id = 2; let decoder_start_token_id = 2; - let n_kvs = match decoder.scale() { + let n_kvs = match &config.scale { Some(Scale::S) => 6, Some(Scale::B) => 12, _ => unimplemented!(), }; + let processor = Processor::try_from_config(&config.processor)? 
+ .with_image_width(width as _) + .with_image_height(height as _); + Ok(Self { encoder, decoder, @@ -83,11 +79,22 @@ impl TrOCR { eos_token_id, decoder_start_token_id, n_kvs, + batch, + width, + height, + processor, }) } + pub fn encode(&mut self, xs: &[Image]) -> Result { + let ys = self.processor.process_images(xs)?; + self.batch = xs.len(); // update + let ys = self.encoder.run(ys.into())?; + Ok(ys[0].to_owned()) + } + pub fn forward(&mut self, xs: &[Image]) -> Result> { - let encoder_hidden_states = elapsed!("encode", self.ts, { self.encoder.encode(xs)? }); + let encoder_hidden_states = elapsed!("encode", self.ts, { self.encode(xs)? }); let generated = elapsed!("generate", self.ts, { self.generate(&encoder_hidden_states)? }); @@ -100,10 +107,10 @@ impl TrOCR { // input_ids let input_ids = X::from(vec![self.decoder_start_token_id as f32]) .insert_axis(0)? - .repeat(0, self.encoder.batch())?; + .repeat(0, self.batch)?; // decoder - let mut decoder_outputs = self.decoder.inference(Xs::from(vec![ + let mut decoder_outputs = self.decoder.run(Xs::from(vec![ input_ids.clone(), encoder_hidden_states.clone(), ]))?; @@ -116,9 +123,9 @@ impl TrOCR { .collect(); // token ids - let mut token_ids: Vec> = vec![vec![]; self.encoder.batch()]; - let mut finished = vec![false; self.encoder.batch()]; - let mut last_tokens: Vec = vec![0.; self.encoder.batch()]; + let mut token_ids: Vec> = vec![vec![]; self.batch]; + let mut finished = vec![false; self.batch]; + let mut last_tokens: Vec = vec![0.; self.batch]; let logits_sampler = LogitsSampler::new(); // generate @@ -169,78 +176,15 @@ impl TrOCR { xs.push(X::ones(&[1])); // use_cache // generate - decoder_outputs = self.decoder_merged.inference(xs.into())?; + decoder_outputs = self.decoder_merged.run(xs.into())?; } Ok(token_ids) } - // fn generate(&mut self, encoder_hidden_states: &X) -> Result>> { - // // input_ids - // let input_ids = X::from(vec![self.decoder_start_token_id as f32]) - // .insert_axis(0)? 
- // .repeat(0, self.encoder.batch())?; - - // // decoder - // let mut decoder_outputs = self.decoder.inference(Xs::from(vec![ - // input_ids.clone(), - // encoder_hidden_states.clone(), - // ]))?; - - // // encoder kvs - // let encoder_kvs: Vec<_> = (3..4 * self.n_kvs) - // .step_by(4) - // .flat_map(|i| [i, i + 1]) - // .map(|i| decoder_outputs[i].clone()) - // .collect(); - - // // token ids - // let mut token_ids: Vec> = vec![vec![]; self.encoder.batch()]; - - // // generate - // for _ in 0..self.max_length { - // let logits = &decoder_outputs[0]; - // let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2) - // .step_by(4) - // .flat_map(|i| [i, i + 1]) - // .map(|i| decoder_outputs[i].clone()) - // .collect(); - - // // decode each token for each batch - // let (finished, last_tokens) = self.decoder_merged.processor().par_generate( - // logits, - // &mut token_ids, - // self.eos_token_id, - // )?; - - // if finished { - // break; - // } - - // // build inputs - // let input_ids = X::from(last_tokens).insert_axis(1)?; - // let mut xs = vec![input_ids, encoder_hidden_states.clone()]; - // for i in 0..self.n_kvs { - // xs.push(decoder_kvs[i * 2].clone()); - // xs.push(decoder_kvs[i * 2 + 1].clone()); - // xs.push(encoder_kvs[i * 2].clone()); - // xs.push(encoder_kvs[i * 2 + 1].clone()); - // } - // xs.push(X::ones(&[1])); // use_cache - - // // generate - // decoder_outputs = self.decoder_merged.inference(xs.into())?; - // } - - // Ok(token_ids) - // } - pub fn decode(&self, token_ids: Vec>) -> Result> { // decode - let texts = self - .decoder_merged - .processor() - .decode_tokens_batch(&token_ids, false)?; + let texts = self.processor.decode_tokens_batch(&token_ids, false)?; // to texts let texts = texts @@ -250,101 +194,8 @@ impl TrOCR { Ok(texts) } + + pub fn summary(&self) { + self.ts.summary(); + } } - -// #[derive(Debug, Builder)] -// pub struct TrOCREncoder { -// // TODO: `BaseVisualEncoder`, `BaseVisualModel` struct? 
-// engine: Engine, -// height: usize, -// width: usize, -// batch: usize, -// processor: Processor, -// } - -// impl TrOCREncoder { -// pub fn new(options: Options) -> Result { -// let engine = options.to_engine()?; -// let (batch, height, width) = ( -// engine.batch().opt(), -// engine.try_height().unwrap_or(&384.into()).opt(), -// engine.try_width().unwrap_or(&384.into()).opt(), -// ); -// let processor = options -// .to_processor()? -// .with_image_width(width as _) -// .with_image_height(height as _); - -// Ok(Self { -// engine, -// height, -// width, -// batch, -// processor, -// }) -// } - -// pub fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { -// self.batch = xs.len(); // TODO -// let x = self.processor.process_images(xs)?; - -// Ok(x.into()) -// } - -// pub fn inference(&mut self, xs: Xs) -> Result { -// self.engine.run(xs) -// } - -// fn encode(&mut self, xs: &[DynamicImage]) -> Result { -// // encode a batch of images into one embedding, that's `X` -// let xs = self.preprocess(xs)?; -// let xs = self.inference(xs)?; -// let x = xs[0].to_owned(); - -// Ok(x) -// } -// } - -// #[derive(Debug, Builder)] -// pub struct TrOCRDecoder { -// engine: Engine, -// batch: usize, -// } - -// impl TrOCRDecoder { -// pub fn new(options: Options) -> Result { -// let engine = options.to_engine()?; -// let batch = engine.batch().opt(); - -// Ok(Self { engine, batch }) -// } - -// pub fn inference(&mut self, xs: Xs) -> Result { -// self.engine.run(xs) -// } -// } - -// #[derive(Debug, Builder)] -// pub struct TrOCRDecoderMerged { -// engine: Engine, -// batch: usize, -// processor: Processor, -// } - -// impl TrOCRDecoderMerged { -// pub fn new(options: Options) -> Result { -// let engine = options.to_engine()?; -// let batch = engine.batch().opt(); -// let processor = options.to_processor()?; - -// Ok(Self { -// engine, -// batch, -// processor, -// }) -// } - -// pub fn inference(&mut self, xs: Xs) -> Result { -// self.engine.run(xs) -// } -// } diff --git 
a/src/models/yolo/config.rs b/src/models/yolo/config.rs index 8034827..7941d03 100644 --- a/src/models/yolo/config.rs +++ b/src/models/yolo/config.rs @@ -1,234 +1,134 @@ use crate::{ - models::YOLOPredsFormat, Options, ResizeMode, Scale, Task, NAMES_COCO_KEYPOINTS_17, - NAMES_YOLO_DOCLAYOUT_10, + models::YOLOPredsFormat, ModelConfig, ResizeMode, Scale, Task, NAMES_COCO_80, + NAMES_COCO_KEYPOINTS_17, NAMES_IMAGENET_1K, NAMES_YOLO_DOCLAYOUT_10, }; -impl Options { +impl ModelConfig { pub fn yolo() -> Self { Self::default() - .with_model_name("yolo") + .with_name("yolo") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 1, 3.into()) .with_model_ixx(0, 2, 640.into()) .with_model_ixx(0, 3, 640.into()) .with_resize_mode(ResizeMode::FitAdaptive) .with_resize_filter("CatmullRom") - .with_find_contours(true) - } - - pub fn doclayout_yolo_docstructbench() -> Self { - Self::yolo_v10() - .with_model_file("doclayout-docstructbench.onnx") // TODO: batch_size > 1 - .with_model_ixx(0, 2, (640, 1024, 1024).into()) - .with_model_ixx(0, 3, (640, 1024, 1024).into()) - .with_class_confs(&[0.4]) - .with_class_names(&NAMES_YOLO_DOCLAYOUT_10) + .with_class_names(&NAMES_COCO_80) } pub fn yolo_classify() -> Self { Self::yolo() - .with_model_task(Task::ImageClassification) + .with_task(Task::ImageClassification) .with_model_ixx(0, 2, 224.into()) .with_model_ixx(0, 3, 224.into()) .with_resize_mode(ResizeMode::FitExact) .with_resize_filter("Bilinear") + .with_class_names(&NAMES_IMAGENET_1K) } pub fn yolo_detect() -> Self { - Self::yolo().with_model_task(Task::ObjectDetection) + Self::yolo().with_task(Task::ObjectDetection) } pub fn yolo_pose() -> Self { Self::yolo() - .with_model_task(Task::KeypointsDetection) + .with_task(Task::KeypointsDetection) .with_keypoint_names(&NAMES_COCO_KEYPOINTS_17) } pub fn yolo_segment() -> Self { - Self::yolo().with_model_task(Task::InstanceSegmentation) + Self::yolo().with_task(Task::InstanceSegmentation) } pub fn yolo_obb() -> Self { - 
Self::yolo().with_model_task(Task::OrientedObjectDetection) + Self::yolo().with_task(Task::OrientedObjectDetection) } - pub fn fastsam_s() -> Self { - Self::yolo_segment() - .with_model_scale(Scale::S) - .with_model_version(8.into()) - .with_model_file("FastSAM-s.onnx") + pub fn auto_yolo_model_file(mut self) -> Self { + if self.model.file.is_empty() { + // [version]-[scale]-[task] + let mut y = String::new(); + if let Some(x) = self.version() { + y.push_str(&x.to_string()); + } + if let Some(x) = self.scale() { + y.push_str(&format!("-{}", x)); + } + if let Some(x) = self.task() { + y.push_str(&format!("-{}", x.yolo_str())); + } + y.push_str(".onnx"); + self.model.file = y; + } + + self } - pub fn yolo_v8_rtdetr() -> Self { - Self::yolo() - .with_model_version(7.into()) - .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) - } - - pub fn yolo_v8_rtdetr_l() -> Self { - Self::yolo_v8_rtdetr() - .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) - .with_model_scale(Scale::L) - .with_model_file("rtdetr-l-det.onnx") - } - - pub fn yolo_v8_rtdetr_x() -> Self { - Self::yolo_v8_rtdetr() - .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) - .with_model_scale(Scale::X) - } - - pub fn yolo_n() -> Self { - Self::yolo().with_model_scale(Scale::N) - } - - pub fn yolo_s() -> Self { - Self::yolo().with_model_scale(Scale::S) - } - - pub fn yolo_m() -> Self { - Self::yolo().with_model_scale(Scale::M) - } - - pub fn yolo_l() -> Self { - Self::yolo().with_model_scale(Scale::L) - } - - pub fn yolo_x() -> Self { - Self::yolo().with_model_scale(Scale::X) - } - - pub fn yolo_v5() -> Self { - Self::yolo().with_model_version(5.into()) - } - - pub fn yolo_v6() -> Self { - Self::yolo().with_model_version(6.into()) - } - - pub fn yolo_v7() -> Self { - Self::yolo().with_model_version(7.into()) - } - - pub fn yolo_v8() -> Self { - Self::yolo().with_model_version(8.into()) - } - - pub fn yolo_v9() -> Self { - Self::yolo().with_model_version(9.into()) - } - - pub 
fn yolo_v10() -> Self { - Self::yolo().with_model_version(10.into()) - } - - pub fn yolo_v11() -> Self { - Self::yolo().with_model_version(11.into()) - } - - pub fn yolo_v12() -> Self { - Self::yolo().with_model_version(12.into()) - } - - pub fn yolo_v8_n() -> Self { - Self::yolo() - .with_model_version(8.into()) - .with_model_scale(Scale::N) - } - - pub fn yolo_v8_s() -> Self { - Self::yolo() - .with_model_version(8.into()) - .with_model_scale(Scale::S) - } - - pub fn yolo_v8_m() -> Self { - Self::yolo() - .with_model_version(8.into()) - .with_model_scale(Scale::M) - } - - pub fn yolo_v8_l() -> Self { - Self::yolo() - .with_model_version(8.into()) - .with_model_scale(Scale::L) - } - - pub fn yolo_v8_x() -> Self { - Self::yolo() - .with_model_version(8.into()) - .with_model_scale(Scale::X) - } - - pub fn yolo_v11_n() -> Self { - Self::yolo() - .with_model_version(11.into()) - .with_model_scale(Scale::N) - } - - pub fn yolo_v11_s() -> Self { - Self::yolo() - .with_model_version(11.into()) - .with_model_scale(Scale::S) - } - - pub fn yolo_v11_m() -> Self { - Self::yolo() - .with_model_version(11.into()) - .with_model_scale(Scale::M) - } - - pub fn yolo_v11_l() -> Self { - Self::yolo() - .with_model_version(11.into()) - .with_model_scale(Scale::L) - } - - pub fn yolo_v11_x() -> Self { - Self::yolo() - .with_model_version(11.into()) - .with_model_scale(Scale::X) + pub fn doclayout_yolo_docstructbench() -> Self { + Self::yolo_detect() + .with_version(10.into()) + .with_model_ixx(0, 2, (640, 1024, 1024).into()) + .with_model_ixx(0, 3, (640, 1024, 1024).into()) + .with_class_confs(&[0.4]) + .with_class_names(&NAMES_YOLO_DOCLAYOUT_10) + .with_model_file("doclayout-docstructbench.onnx") // TODO: batch_size > 1 } + // YOLOE models pub fn yoloe_v8s_seg_pf() -> Self { - Self::yolo() - .with_model_version(8.into()) - .with_model_scale(Scale::S) + Self::yolo_segment() + .with_version(8.into()) + .with_scale(Scale::S) .with_model_file("yoloe-v8s-seg-pf.onnx") } pub fn 
yoloe_v8m_seg_pf() -> Self { - Self::yolo() - .with_model_version(8.into()) - .with_model_scale(Scale::M) + Self::yolo_segment() + .with_version(8.into()) + .with_scale(Scale::M) .with_model_file("yoloe-v8m-seg-pf.onnx") } pub fn yoloe_v8l_seg_pf() -> Self { - Self::yolo() - .with_model_version(8.into()) - .with_model_scale(Scale::L) + Self::yolo_segment() + .with_version(8.into()) + .with_scale(Scale::L) .with_model_file("yoloe-v8l-seg-pf.onnx") } pub fn yoloe_11s_seg_pf() -> Self { - Self::yolo() - .with_model_version(11.into()) - .with_model_scale(Scale::S) + Self::yolo_segment() + .with_version(11.into()) + .with_scale(Scale::S) .with_model_file("yoloe-11s-seg-pf.onnx") } pub fn yoloe_11m_seg_pf() -> Self { - Self::yolo() - .with_model_version(11.into()) - .with_model_scale(Scale::M) + Self::yolo_segment() + .with_version(11.into()) + .with_scale(Scale::M) .with_model_file("yoloe-v8m-seg-pf.onnx") } pub fn yoloe_11l_seg_pf() -> Self { - Self::yolo() - .with_model_version(11.into()) - .with_model_scale(Scale::L) + Self::yolo_segment() + .with_version(11.into()) + .with_scale(Scale::L) .with_model_file("yoloe-11l-seg-pf.onnx") } + + /// ---- TODO + pub fn fastsam_s() -> Self { + Self::yolo_segment() + .with_scale(Scale::S) + .with_version(8.into()) + .with_model_file("FastSAM-s.onnx") + } + + pub fn yolo_v8_rtdetr_l() -> Self { + Self::yolo_detect() + .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) + .with_scale(Scale::L) + .with_model_file("rtdetr-l-det.onnx") + } } diff --git a/src/models/yolo/impl.rs b/src/models/yolo/impl.rs index a15d902..7bf550d 100644 --- a/src/models/yolo/impl.rs +++ b/src/models/yolo/impl.rs @@ -8,8 +8,8 @@ use regex::Regex; use crate::{ elapsed, models::{BoxType, YOLOPredsFormat}, - DynConf, Engine, Hbb, Image, Keypoint, Mask, NmsOps, Obb, Ops, Options, Prob, Processor, Task, - Ts, Version, Xs, Y, + DynConf, Engine, Hbb, Image, Keypoint, Mask, ModelConfig, NmsOps, Obb, Ops, Prob, Processor, + Task, Ts, Version, Xs, Y, }; 
#[derive(Debug, Builder)] @@ -36,17 +36,18 @@ pub struct YOLO { topk: usize, } -impl TryFrom for YOLO { +impl TryFrom for YOLO { type Error = anyhow::Error; - fn try_from(options: Options) -> Result { - Self::new(options) + fn try_from(config: ModelConfig) -> Result { + Self::new(config) } } impl YOLO { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; + let (batch, height, width, ts, spec) = ( engine.batch().opt(), engine.try_height().unwrap_or(&640.into()).opt(), @@ -54,11 +55,8 @@ impl YOLO { engine.ts.clone(), engine.spec().to_owned(), ); - let processor = options - .to_processor()? - .with_image_width(width as _) - .with_image_height(height as _); - let task: Option = match &options.model_task { + + let task: Option = match &config.task { Some(task) => Some(task.clone()), None => match engine.try_fetch("task") { Some(x) => match x.as_str() { @@ -77,8 +75,8 @@ impl YOLO { }; // Task & layout - let version = options.model_version; - let (layout, task) = match &options.yolo_preds_format { + let version = config.version; + let (layout, task) = match &config.yolo_preds_format { // customized Some(layout) => { // check task @@ -172,7 +170,7 @@ impl YOLO { // Class names let names: Option> = match Self::fetch_names_from_onnx(&engine) { - Some(names_parsed) => match &options.class_names { + Some(names_parsed) => match &config.class_names { Some(names) => { if names.len() == names_parsed.len() { // prioritize user-defined @@ -188,25 +186,25 @@ impl YOLO { } None => Some(names_parsed), }, - None => options.class_names.clone(), + None => config.class_names.clone(), }; // Class names & Number of class - let (nc, names) = match (options.nc(), names) { + let (nc, names) = match (config.nc(), names) { (_, Some(names)) => (names.len(), names.to_vec()), (Some(nc), None) => (nc, Self::n2s(nc)), (None, None) => { anyhow::bail!( "Neither class names 
nor the number of classes were specified. \ - \nConsider specify them with `Options::default().with_nc()` or `Options::default().with_class_names()`" + \nConsider specify them with `ModelConfig::default().with_nc()` or `ModelConfig::default().with_class_names()`" ); } }; // Keypoint names & Number of keypoints let (nk, names_kpt) = if let Task::KeypointsDetection = task { - let nk = Self::fetch_nk_from_onnx(&engine).or(options.nk()); - match (&options.keypoint_names, nk) { + let nk = Self::fetch_nk_from_onnx(&engine).or(config.nk()); + match (&config.keypoint_names, nk) { (Some(names), Some(nk)) => { if names.len() != nk { anyhow::bail!( @@ -221,7 +219,7 @@ impl YOLO { (None, Some(nk)) => (nk, Self::n2s(nk)), (None, None) => anyhow::bail!( "Neither keypoint names nor the number of keypoints were specified when doing `KeypointsDetection` task. \ - \nConsider specify them with `Options::default().with_nk()` or `Options::default().with_keypoint_names()`" + \nConsider specify them with `ModelConfig::default().with_nk()` or `ModelConfig::default().with_keypoint_names()`" ), } } else { @@ -229,12 +227,12 @@ impl YOLO { }; // Attributes - let topk = options.topk().unwrap_or(5); - let confs = DynConf::new(options.class_confs(), nc); - let kconfs = DynConf::new(options.keypoint_confs(), nk); - let iou = options.iou().unwrap_or(0.45); - let classes_excluded = options.classes_excluded().to_vec(); - let classes_retained = options.classes_retained().to_vec(); + let topk = config.topk().unwrap_or(5); + let confs = DynConf::new(config.class_confs(), nc); + let kconfs = DynConf::new(config.keypoint_confs(), nk); + let iou = config.iou().unwrap_or(0.45); + let classes_excluded = config.classes_excluded().to_vec(); + let classes_retained = config.classes_retained().to_vec(); let mut info = format!( "YOLO Version: {}, Task: {:?}, Category Count: {}, Keypoint Count: {}, TopK: {}", version.map_or("Unknown".into(), |x| x.to_string()), @@ -249,6 +247,10 @@ impl YOLO { if 
!classes_retained.is_empty() { info = format!("{}, classes_retained: {:?}", info, classes_retained); } + let processor = Processor::try_from_config(&config.processor)? + .with_image_width(width as _) + .with_image_height(height as _); + info!("{}", info); Ok(Self { diff --git a/src/models/yolop/config.rs b/src/models/yolop/config.rs index 6e1564e..9736c7d 100644 --- a/src/models/yolop/config.rs +++ b/src/models/yolop/config.rs @@ -1,14 +1,12 @@ /// Model configuration for `YOLOP` -impl crate::Options { +impl crate::ModelConfig { pub fn yolop() -> Self { Self::default() - .with_model_name("yolop") + .with_name("yolop") .with_model_ixx(0, 0, 1.into()) .with_model_ixx(0, 2, 640.into()) .with_model_ixx(0, 3, 640.into()) .with_resize_mode(crate::ResizeMode::FitAdaptive) - .with_resize_filter("Bilinear") - .with_normalize(true) .with_class_confs(&[0.3]) } diff --git a/src/models/yolop/impl.rs b/src/models/yolop/impl.rs index 1cbde83..b9a5c3f 100644 --- a/src/models/yolop/impl.rs +++ b/src/models/yolop/impl.rs @@ -3,7 +3,7 @@ use anyhow::Result; use ndarray::{s, Array, Axis, IxDyn}; use crate::{ - elapsed, DynConf, Engine, Hbb, Image, NmsOps, Ops, Options, Polygon, Processor, Ts, Xs, Y, + elapsed, DynConf, Engine, Hbb, Image, ModelConfig, NmsOps, Ops, Polygon, Processor, Ts, Xs, Y, }; #[derive(Builder, Debug)] @@ -20,8 +20,8 @@ pub struct YOLOPv2 { } impl YOLOPv2 { - pub fn new(options: Options) -> Result { - let engine = options.to_engine()?; + pub fn new(config: ModelConfig) -> Result { + let engine = Engine::try_from_config(&config.model)?; let spec = engine.spec().to_string(); let (batch, height, width, ts) = ( engine.batch().opt(), @@ -29,14 +29,13 @@ impl YOLOPv2 { engine.try_width().unwrap_or(&512.into()).opt(), engine.ts().clone(), ); - let processor = options - .to_processor()? + + let confs = DynConf::new(config.class_confs(), 80); + let iou = config.iou.unwrap_or(0.45f32); + let processor = Processor::try_from_config(&config.processor)? 
.with_image_width(width as _) .with_image_height(height as _); - let confs = DynConf::new(options.class_confs(), 80); - let iou = options.iou.unwrap_or(0.45f32); - Ok(Self { engine, height, diff --git a/src/utils/device.rs b/src/utils/device.rs index 8c612cd..a8099cf 100644 --- a/src/utils/device.rs +++ b/src/utils/device.rs @@ -1,16 +1,20 @@ #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum Device { - Auto(usize), Cpu(usize), Cuda(usize), TensorRT(usize), CoreML(usize), } +impl Default for Device { + fn default() -> Self { + Self::Cpu(0) + } +} + impl std::fmt::Display for Device { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let x = match self { - Self::Auto(i) => format!("auto:{}", i), Self::Cpu(i) => format!("cpu:{}", i), Self::Cuda(i) => format!("cuda:{}", i), Self::CoreML(i) => format!("mps:{}", i), @@ -47,7 +51,6 @@ impl TryFrom<&str> for Device { impl Device { pub fn id(&self) -> usize { match self { - Device::Auto(i) => *i, Device::Cpu(i) => *i, Device::Cuda(i) => *i, Device::TensorRT(i) => *i, diff --git a/src/utils/dtype.rs b/src/utils/dtype.rs index 538ae83..5874007 100644 --- a/src/utils/dtype.rs +++ b/src/utils/dtype.rs @@ -1,5 +1,6 @@ -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum DType { + #[default] Auto, Int4, Int8, diff --git a/src/utils/kind.rs b/src/utils/kind.rs deleted file mode 100644 index 4519427..0000000 --- a/src/utils/kind.rs +++ /dev/null @@ -1,18 +0,0 @@ -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] -pub enum Kind { - // Do we really need this? 
- Vision, - Language, - VisionLanguage, -} - -impl std::fmt::Display for Kind { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let x = match self { - Self::Vision => "visual", - Self::Language => "textual", - Self::VisionLanguage => "vl", - }; - write!(f, "{}", x) - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 50b8068..008dfb6 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -2,13 +2,12 @@ mod device; mod dtype; mod dynconf; mod iiix; -mod kind; mod logits_sampler; mod min_opt_max; mod names; mod ops; -mod options; mod processor; +mod processor_config; mod retry; mod scale; mod task; @@ -20,13 +19,12 @@ pub use device::Device; pub use dtype::DType; pub use dynconf::DynConf; pub(crate) use iiix::Iiix; -pub use kind::Kind; pub use logits_sampler::LogitsSampler; pub use min_opt_max::MinOptMax; pub use names::*; pub use ops::*; -pub use options::*; pub use processor::*; +pub use processor_config::ProcessorConfig; pub use scale::Scale; pub use task::Task; pub use traits::*; diff --git a/src/utils/ops.rs b/src/utils/ops.rs index 8d416ce..9f54162 100644 --- a/src/utils/ops.rs +++ b/src/utils/ops.rs @@ -267,12 +267,12 @@ impl Ops<'_> { PixelType::F32, )?; let mut dst = Image::new(w1 as _, h1 as _, src.pixel_type()); - let (mut resizer, mut options) = Self::build_resizer_filter(filter)?; + let (mut resizer, mut config) = Self::build_resizer_filter(filter)?; if crop_src { let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _); - options = options.crop(0., 0., w.into(), h.into()); + config = config.crop(0., 0., w.into(), h.into()); }; - resizer.resize(&src, &mut dst, &options)?; + resizer.resize(&src, &mut dst, &config)?; // u8 -> f32 Self::u8_slice_to_f32(&dst.into_vec()) @@ -317,12 +317,12 @@ impl Ops<'_> { ) -> Result> { let src = Image::from_vec_u8(w0 as _, h0 as _, v.to_vec(), PixelType::U8)?; let mut dst = Image::new(w1 as _, h1 as _, src.pixel_type()); - let (mut resizer, mut options) = 
Self::build_resizer_filter(filter)?; + let (mut resizer, mut config) = Self::build_resizer_filter(filter)?; if crop_src { let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _); - options = options.crop(0., 0., w.into(), h.into()); + config = config.crop(0., 0., w.into(), h.into()); }; - resizer.resize(&src, &mut dst, &options)?; + resizer.resize(&src, &mut dst, &config)?; Ok(dst.into_vec()) } @@ -348,13 +348,13 @@ impl Ops<'_> { th: u32, tw: u32, resizer: &mut Resizer, - options: &ResizeOptions, + config: &ResizeOptions, ) -> Result> { let buffer = if x.dimensions() == (tw, th) { x.to_rgb8().into_raw() } else { let mut dst = Image::new(tw, th, PixelType::U8x3); - resizer.resize(x, &mut dst, options)?; + resizer.resize(x, &mut dst, config)?; dst.into_vec() }; let y = Array::from_shape_vec((th as usize, tw as usize, 3), buffer)? @@ -370,9 +370,9 @@ impl Ops<'_> { filter: &str, ) -> Result> { let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn(); - let (mut resizer, options) = Self::build_resizer_filter(filter)?; + let (mut resizer, config) = Self::build_resizer_filter(filter)?; for (idx, x) in xs.iter().enumerate() { - let y = Self::resize_rgb(x, th, tw, &mut resizer, &options)?; + let y = Self::resize_rgb(x, th, tw, &mut resizer, &config)?; ys.slice_mut(s![idx, .., .., ..]).assign(&y); } Ok(ys) @@ -387,7 +387,7 @@ impl Ops<'_> { resize_by: &str, center: bool, resizer: &mut Resizer, - options: &ResizeOptions, + config: &ResizeOptions, ) -> Result> { let (w0, h0) = x.dimensions(); let buffer = if w0 == tw && h0 == th { @@ -403,7 +403,7 @@ impl Ops<'_> { } "height" => (th * w0 / h0, th), "width" => (tw, tw * h0 / w0), - _ => anyhow::bail!("ModelConfig for `letterbox`: width, height, auto"), + _ => anyhow::bail!("EngineConfig for `letterbox`: width, height, auto"), }; let mut dst = Image::from_vec_u8( @@ -422,7 +422,7 @@ impl Ops<'_> { (0, 0) }; let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; - resizer.resize(x, 
&mut dst_cropped, options)?; + resizer.resize(x, &mut dst_cropped, config)?; dst.into_vec() }; let y = Array::from_shape_vec((th as usize, tw as usize, 3), buffer)? @@ -441,9 +441,9 @@ impl Ops<'_> { center: bool, ) -> Result> { let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn(); - let (mut resizer, options) = Self::build_resizer_filter(filter)?; + let (mut resizer, config) = Self::build_resizer_filter(filter)?; for (idx, x) in xs.iter().enumerate() { - let y = Self::letterbox_rgb(x, th, tw, bg, resize_by, center, &mut resizer, &options)?; + let y = Self::letterbox_rgb(x, th, tw, bg, resize_by, center, &mut resizer, &config)?; ys.slice_mut(s![idx, .., .., ..]).assign(&y); } Ok(ys) diff --git a/src/utils/options.rs b/src/utils/options.rs deleted file mode 100644 index 613a620..0000000 --- a/src/utils/options.rs +++ /dev/null @@ -1,488 +0,0 @@ -//! Options for everthing -use aksr::Builder; -use anyhow::Result; -use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; - -use crate::{ - models::{SamKind, YOLOPredsFormat}, - try_fetch_file_stem, DType, Device, Engine, Hub, Iiix, Kind, LogitsSampler, MinOptMax, - Processor, ResizeMode, Scale, Task, Version, -}; - -/// Options for building models and inference -#[derive(Builder, Debug, Clone)] -pub struct Options { - // Model configs - pub model_file: String, - pub model_name: &'static str, - pub model_device: Device, - pub model_dtype: DType, - pub model_version: Option, - pub model_task: Option, - pub model_scale: Option, - pub model_kind: Option, - pub model_iiixs: Vec, - pub model_spec: String, - pub model_num_dry_run: usize, - pub trt_fp16: bool, - pub profile: bool, - - // models - pub model_encoder_file: Option, - pub model_decoder_file: Option, - pub visual_encoder_file: Option, - pub visual_decoder_file: Option, - pub textual_encoder_file: Option, - pub textual_decoder_file: Option, - - // Processor configs - #[args(except(setter))] - pub image_width: u32, - 
#[args(except(setter))] - pub image_height: u32, - pub resize_mode: ResizeMode, - pub resize_filter: &'static str, - pub padding_value: u8, - pub letterbox_center: bool, - pub normalize: bool, - pub image_std: Vec, - pub image_mean: Vec, - pub nchw: bool, - pub unsigned: bool, - - // Names - pub class_names: Option>, - pub class_names_2: Option>, - pub class_names_3: Option>, - pub keypoint_names: Option>, - pub keypoint_names_2: Option>, - pub keypoint_names_3: Option>, - pub text_names: Option>, - pub text_names_2: Option>, - pub text_names_3: Option>, - pub category_names: Option>, - pub category_names_2: Option>, - pub category_names_3: Option>, - - // Confs - pub class_confs: Vec, - pub class_confs_2: Vec, - pub class_confs_3: Vec, - pub keypoint_confs: Vec, - pub keypoint_confs_2: Vec, - pub keypoint_confs_3: Vec, - pub text_confs: Vec, - pub text_confs_2: Vec, - pub text_confs_3: Vec, - - // Files - pub file: Option, - pub file_2: Option, - pub file_3: Option, - - // For classification - pub apply_softmax: Option, - pub topk: Option, - pub topk_2: Option, - pub topk_3: Option, - - // For detection - #[args(aka = "nc")] - pub num_classes: Option, - #[args(aka = "nk")] - pub num_keypoints: Option, - #[args(aka = "nm")] - pub num_masks: Option, - pub iou: Option, - pub iou_2: Option, - pub iou_3: Option, - pub apply_nms: Option, - pub find_contours: bool, - pub yolo_preds_format: Option, - pub classes_excluded: Vec, - pub classes_retained: Vec, - pub min_width: Option, - pub min_height: Option, - - // Language models related - pub model_max_length: Option, - pub tokenizer_file: Option, - pub config_file: Option, - pub special_tokens_map_file: Option, - pub tokenizer_config_file: Option, - pub generation_config_file: Option, - pub vocab_file: Option, // vocab.json file - pub vocab_txt: Option, // vacab.txt file, not kv pairs - pub temperature: f32, - pub topp: f32, - - // For DB - pub unclip_ratio: Option, - pub binary_thresh: Option, - - // For SAM - pub 
sam_kind: Option, // TODO: remove - pub low_res_mask: Option, // TODO: remove - - // Others - pub ort_graph_opt_level: Option, -} - -impl Default for Options { - fn default() -> Self { - Self { - model_file: Default::default(), - model_name: Default::default(), - model_version: Default::default(), - model_task: Default::default(), - model_scale: Default::default(), - model_kind: Default::default(), - model_device: Device::Cpu(0), - model_dtype: DType::Auto, - model_spec: Default::default(), - model_iiixs: Default::default(), - model_num_dry_run: 3, - trt_fp16: true, - profile: false, - normalize: true, - image_mean: vec![], - image_std: vec![], - image_height: 640, - image_width: 640, - padding_value: 114, - resize_mode: ResizeMode::FitExact, - resize_filter: "Bilinear", - letterbox_center: false, - nchw: true, - unsigned: false, - class_names: None, - class_names_2: None, - class_names_3: None, - category_names: None, - category_names_2: None, - category_names_3: None, - keypoint_names: None, - keypoint_names_2: None, - keypoint_names_3: None, - text_names: None, - text_names_2: None, - text_names_3: None, - file: None, - file_2: None, - file_3: None, - class_confs: vec![0.3f32], - class_confs_2: vec![0.3f32], - class_confs_3: vec![0.3f32], - keypoint_confs: vec![0.3f32], - keypoint_confs_2: vec![0.5f32], - keypoint_confs_3: vec![0.5f32], - text_confs: vec![0.4f32], - text_confs_2: vec![0.4f32], - text_confs_3: vec![0.4f32], - apply_softmax: Some(false), - num_classes: None, - num_keypoints: None, - num_masks: None, - iou: None, - iou_2: None, - iou_3: None, - find_contours: false, - yolo_preds_format: None, - classes_excluded: vec![], - classes_retained: vec![], - apply_nms: None, - model_max_length: None, - tokenizer_file: None, - config_file: None, - special_tokens_map_file: None, - tokenizer_config_file: None, - generation_config_file: None, - vocab_file: None, - vocab_txt: None, - min_width: None, - min_height: None, - unclip_ratio: Some(1.5), - 
binary_thresh: Some(0.2), - sam_kind: None, - low_res_mask: None, - temperature: 1., - topp: 0., - topk: None, - topk_2: None, - topk_3: None, - ort_graph_opt_level: None, - model_encoder_file: None, - model_decoder_file: None, - visual_encoder_file: None, - visual_decoder_file: None, - textual_encoder_file: None, - textual_decoder_file: None, - } - } -} - -impl Options { - pub fn new() -> Self { - Default::default() - } - - pub fn to_engine(&self) -> Result { - Engine { - file: self.model_file.clone(), - spec: self.model_spec.clone(), - device: self.model_device, - trt_fp16: self.trt_fp16, - iiixs: self.model_iiixs.clone(), - num_dry_run: self.model_num_dry_run, - graph_opt_level: self.ort_graph_opt_level, - ..Default::default() - } - .build() - } - - pub fn to_processor(&self) -> Result { - let logits_sampler = LogitsSampler::new() - .with_temperature(self.temperature) - .with_topp(self.topp); - - // try to build tokenizer - let tokenizer = match self.model_kind { - Some(Kind::Language) | Some(Kind::VisionLanguage) => Some(self.try_build_tokenizer()?), - _ => None, - }; - - // try to build vocab from `vocab.txt` - let vocab: Vec = match &self.vocab_txt { - Some(x) => { - let file = if !std::path::PathBuf::from(&x).exists() { - Hub::default().try_fetch(&format!("{}/{}", self.model_name, x))? - } else { - x.to_string() - }; - std::fs::read_to_string(file)? 
- .lines() - .map(|line| line.to_string()) - .collect() - } - None => vec![], - }; - - Ok(Processor { - image_width: self.image_width, - image_height: self.image_height, - resize_mode: self.resize_mode.clone(), - resize_filter: self.resize_filter, - padding_value: self.padding_value, - do_normalize: self.normalize, - image_mean: self.image_mean.clone(), - image_std: self.image_std.clone(), - nchw: self.nchw, - unsigned: self.unsigned, - tokenizer, - vocab, - logits_sampler: Some(logits_sampler), - ..Default::default() - }) - } - - pub fn commit(mut self) -> Result { - // Identify the local model or fetch the remote model - - if std::path::PathBuf::from(&self.model_file).exists() { - // Local - self.model_spec = format!( - "{}/{}", - self.model_name, - try_fetch_file_stem(&self.model_file)? - ); - } else { - // Remote - if self.model_file.is_empty() && self.model_name.is_empty() { - anyhow::bail!("Neither `model_name` nor `model_file` were specified. Faild to fetch model from remote.") - } - - // Load - match Hub::is_valid_github_release_url(&self.model_file) { - Some((owner, repo, tag, _file_name)) => { - let stem = try_fetch_file_stem(&self.model_file)?; - self.model_spec = - format!("{}/{}-{}-{}-{}", self.model_name, owner, repo, tag, stem); - self.model_file = Hub::default().try_fetch(&self.model_file)?; - } - None => { - // special yolo case - if self.model_file.is_empty() && self.model_name == "yolo" { - // [version]-[scale]-[task] - let mut y = String::new(); - if let Some(x) = self.model_version() { - y.push_str(&x.to_string()); - } - if let Some(x) = self.model_scale() { - y.push_str(&format!("-{}", x)); - } - if let Some(x) = self.model_task() { - y.push_str(&format!("-{}", x.yolo_str())); - } - y.push_str(".onnx"); - self.model_file = y; - } - - // append dtype to model file - match self.model_dtype { - d @ (DType::Auto | DType::Fp32) => { - if self.model_file.is_empty() { - self.model_file = format!("{}.onnx", d); - } - } - dtype => { - if 
self.model_file.is_empty() { - self.model_file = format!("{}.onnx", dtype); - } else { - let pos = self.model_file.len() - 5; // .onnx - let suffix = self.model_file.split_off(pos); - self.model_file = - format!("{}-{}{}", self.model_file, dtype, suffix); - } - } - } - - let stem = try_fetch_file_stem(&self.model_file)?; - self.model_spec = format!("{}/{}", self.model_name, stem); - self.model_file = Hub::default() - .try_fetch(&format!("{}/{}", self.model_name, self.model_file))?; - } - } - } - - Ok(self) - } - - pub fn with_batch_size(mut self, x: usize) -> Self { - self.model_iiixs.push(Iiix::from((0, 0, x.into()))); - self - } - - pub fn with_image_height(mut self, x: u32) -> Self { - self.image_height = x; - self.model_iiixs.push(Iiix::from((0, 2, x.into()))); - self - } - - pub fn with_image_width(mut self, x: u32) -> Self { - self.image_width = x; - self.model_iiixs.push(Iiix::from((0, 3, x.into()))); - self - } - - pub fn with_model_ixx(mut self, i: usize, ii: usize, x: MinOptMax) -> Self { - self.model_iiixs.push(Iiix::from((i, ii, x))); - self - } - - pub fn exclude_classes(mut self, xs: &[usize]) -> Self { - self.classes_retained.clear(); - self.classes_excluded.extend_from_slice(xs); - self - } - - pub fn retain_classes(mut self, xs: &[usize]) -> Self { - self.classes_excluded.clear(); - self.classes_retained.extend_from_slice(xs); - self - } - - pub fn try_build_tokenizer(&self) -> Result { - let mut hub = Hub::default(); - // config file - // TODO: save configs? 
- let pad_id = match hub.try_fetch( - self.tokenizer_config_file - .as_ref() - .unwrap_or(&format!("{}/config.json", self.model_name)), - ) { - Ok(x) => { - let config: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(x)?)?; - config["pad_token_id"].as_u64().unwrap_or(0) as u32 - } - Err(_err) => 0u32, - }; - - // tokenizer_config file - let mut max_length = None; - let mut pad_token = String::from("[PAD]"); - match hub.try_fetch( - self.tokenizer_config_file - .as_ref() - .unwrap_or(&format!("{}/tokenizer_config.json", self.model_name)), - ) { - Err(_) => {} - Ok(x) => { - let tokenizer_config: serde_json::Value = - serde_json::from_str(&std::fs::read_to_string(x)?)?; - max_length = tokenizer_config["model_max_length"].as_u64(); - pad_token = tokenizer_config["pad_token"] - .as_str() - .unwrap_or("[PAD]") - .to_string(); - } - } - - // tokenizer file - let mut tokenizer: tokenizers::Tokenizer = tokenizers::Tokenizer::from_file( - hub.try_fetch( - self.tokenizer_file - .as_ref() - .unwrap_or(&format!("{}/tokenizer.json", self.model_name)), - )?, - ) - .map_err(|err| anyhow::anyhow!("Faild to build tokenizer: {err}"))?; - - // TODO: padding - // if `max_length` specified: use `Fixed` strategy - // else: use `BatchLongest` strategy - // TODO: if sequence_length is dynamic, `BatchLongest` is fine - let tokenizer = match self.model_max_length { - Some(n) => { - let n = match max_length { - None => n, - Some(x) => x.min(n), - }; - tokenizer - .with_padding(Some(PaddingParams { - strategy: PaddingStrategy::Fixed(n as _), - pad_token, - pad_id, - ..Default::default() - })) - .clone() - } - None => match max_length { - Some(n) => tokenizer - .with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - pad_token, - pad_id, - ..Default::default() - })) - .with_truncation(Some(TruncationParams { - max_length: n as _, - ..Default::default() - })) - .map_err(|err| anyhow::anyhow!("Failed to truncate: {}", err))? 
- .clone(), - None => tokenizer - .with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - pad_token, - pad_id, - ..Default::default() - })) - .clone(), - }, - }; - - // TODO: generation_config.json & special_tokens_map file - - Ok(tokenizer.into()) - } -} diff --git a/src/utils/processor.rs b/src/utils/processor.rs index f74ee6a..54fe366 100644 --- a/src/utils/processor.rs +++ b/src/utils/processor.rs @@ -5,12 +5,12 @@ use rayon::prelude::*; use std::sync::Mutex; use tokenizers::{Encoding, Tokenizer}; -use crate::{Image, ImageTransformInfo, LogitsSampler, ResizeMode, X}; +use crate::{Hub, Image, ImageTransformInfo, LogitsSampler, ProcessorConfig, ResizeMode, X}; #[derive(Builder, Debug, Clone)] pub struct Processor { - pub image_width: u32, // target image width - pub image_height: u32, // target image height + pub image_width: u32, + pub image_height: u32, pub images_transform_info: Vec, pub resize_mode: ResizeMode, pub resize_filter: &'static str, @@ -47,12 +47,53 @@ impl Default for Processor { } impl Processor { + pub fn try_from_config(config: &ProcessorConfig) -> Result { + let logits_sampler = LogitsSampler::new() + .with_temperature(config.temperature) + .with_topp(config.topp); + + // try to build tokenizer + let tokenizer = config.try_build_tokenizer()?; + + // try to build vocab from `vocab.txt` + let vocab: Vec = match &config.vocab_txt { + Some(x) => { + let file = if !std::path::PathBuf::from(&x).exists() { + Hub::default().try_fetch(x)? + } else { + x.to_string() + }; + std::fs::read_to_string(file)? 
+ .lines() + .map(|line| line.to_string()) + .collect() + } + None => vec![], + }; + + Ok(Processor { + image_width: config.image_width.unwrap_or_default(), + image_height: config.image_height.unwrap_or_default(), + resize_mode: config.resize_mode.clone(), + resize_filter: config.resize_filter.unwrap_or("Bilinear"), + padding_value: config.padding_value, + do_normalize: config.normalize, + image_mean: config.image_mean.clone(), + image_std: config.image_std.clone(), + nchw: config.nchw, + unsigned: config.unsigned, + tokenizer, + vocab, + logits_sampler: Some(logits_sampler), + ..Default::default() + }) + } + pub fn reset_image0_status(&mut self) { self.images_transform_info.clear(); } pub fn process_images(&mut self, xs: &[Image]) -> Result { - // self.reset_image0_status(); let (mut x, images_transform_info) = self.par_resize(xs)?; self.images_transform_info = images_transform_info; diff --git a/src/utils/processor_config.rs b/src/utils/processor_config.rs new file mode 100644 index 0000000..d6a3daf --- /dev/null +++ b/src/utils/processor_config.rs @@ -0,0 +1,245 @@ +use aksr::Builder; +use anyhow::Result; +use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; + +use crate::{Hub, ResizeMode}; + +#[derive(Builder, Debug, Clone)] +pub struct ProcessorConfig { + // Vision + pub image_width: Option, + pub image_height: Option, + pub resize_mode: ResizeMode, + pub resize_filter: Option<&'static str>, + pub padding_value: u8, + pub normalize: bool, + pub image_std: Vec, + pub image_mean: Vec, + pub nchw: bool, + pub unsigned: bool, + + // Text + pub model_max_length: Option, + pub tokenizer_file: Option, + pub config_file: Option, + pub special_tokens_map_file: Option, + pub tokenizer_config_file: Option, + pub generation_config_file: Option, + pub vocab_file: Option, + pub vocab_txt: Option, + pub temperature: f32, + pub topp: f32, +} + +impl Default for ProcessorConfig { + fn default() -> Self { + Self { + image_width: None, + image_height: 
None, + resize_mode: ResizeMode::FitExact, + resize_filter: Some("Bilinear"), + padding_value: 114, + normalize: true, + image_std: vec![], + image_mean: vec![], + nchw: true, + unsigned: false, + model_max_length: None, + tokenizer_file: None, + config_file: None, + special_tokens_map_file: None, + tokenizer_config_file: None, + generation_config_file: None, + vocab_file: None, + vocab_txt: None, + temperature: 1.0, + topp: 0.9, + } + } +} + +impl ProcessorConfig { + pub fn try_build_tokenizer(&self) -> Result> { + let mut hub = Hub::default(); + + // tokenizer file + let mut tokenizer: Tokenizer = match &self.tokenizer_file { + None => return Ok(None), + Some(file) => Tokenizer::from_file(hub.try_fetch(file)?) + .map_err(|err| anyhow::anyhow!("Faild to build tokenizer: {err}"))?, + }; + + // config file + // TODO: save configs? + let pad_id = match &self.tokenizer_config_file { + None => 0u32, + Some(file) => match hub.try_fetch(file) { + Ok(x) => { + let config: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(x)?)?; + config["pad_token_id"].as_u64().unwrap_or(0) as u32 + } + Err(_err) => 0u32, + }, + }; + + // tokenizer_config file + let mut max_length = None; + let mut pad_token = String::from("[PAD]"); + + if let Some(file) = &self.tokenizer_config_file { + match hub.try_fetch(file) { + Err(_) => {} + Ok(x) => { + let tokenizer_config: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(x)?)?; + max_length = tokenizer_config["model_max_length"].as_u64(); + pad_token = tokenizer_config["pad_token"] + .as_str() + .unwrap_or("[PAD]") + .to_string(); + } + } + } + + // TODO: padding + // if `max_length` specified: use `Fixed` strategy + // else: use `BatchLongest` strategy + // TODO: if sequence_length is dynamic, `BatchLongest` is fine + let tokenizer = match self.model_max_length { + Some(n) => { + let n = match max_length { + None => n, + Some(x) => x.min(n), + }; + tokenizer + .with_padding(Some(PaddingParams { + strategy: 
PaddingStrategy::Fixed(n as _), + pad_token, + pad_id, + ..Default::default() + })) + .clone() + } + None => match max_length { + Some(n) => tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .with_truncation(Some(TruncationParams { + max_length: n as _, + ..Default::default() + })) + .map_err(|err| anyhow::anyhow!("Failed to truncate: {}", err))? + .clone(), + None => tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .clone(), + }, + }; + + Ok(Some(tokenizer.into())) + } +} + +#[macro_export] +macro_rules! impl_process_config_methods { + ($ty:ty, $field:ident) => { + impl $ty { + pub fn with_image_width(mut self, image_width: u32) -> Self { + self.$field = self.$field.with_image_width(image_width); + self + } + pub fn with_image_height(mut self, image_height: u32) -> Self { + self.$field = self.$field.with_image_height(image_height); + self + } + pub fn with_resize_mode(mut self, resize_mode: $crate::ResizeMode) -> Self { + self.$field = self.$field.with_resize_mode(resize_mode); + self + } + pub fn with_resize_filter(mut self, resize_filter: &'static str) -> Self { + self.$field = self.$field.with_resize_filter(resize_filter); + self + } + pub fn with_padding_value(mut self, padding_value: u8) -> Self { + self.$field = self.$field.with_padding_value(padding_value); + self + } + pub fn with_normalize(mut self, normalize: bool) -> Self { + self.$field = self.$field.with_normalize(normalize); + self + } + pub fn with_image_std(mut self, image_std: &[f32]) -> Self { + self.$field = self.$field.with_image_std(image_std); + self + } + pub fn with_image_mean(mut self, image_mean: &[f32]) -> Self { + self.$field = self.$field.with_image_mean(image_mean); + self + } + pub fn with_nchw(mut self, nchw: bool) -> Self { + self.$field = self.$field.with_nchw(nchw); + self + } + pub fn 
with_unsigned(mut self, unsigned: bool) -> Self { + self.$field = self.$field.with_unsigned(unsigned); + self + } + pub fn with_model_max_length(mut self, model_max_length: u64) -> Self { + self.$field = self.$field.with_model_max_length(model_max_length); + self + } + pub fn with_tokenizer_file(mut self, tokenizer_file: &str) -> Self { + self.$field = self.$field.with_tokenizer_file(tokenizer_file); + self + } + pub fn with_config_file(mut self, config_file: &str) -> Self { + self.$field = self.$field.with_config_file(config_file); + self + } + pub fn with_special_tokens_map_file(mut self, special_tokens_map_file: &str) -> Self { + self.$field = self + .$field + .with_special_tokens_map_file(special_tokens_map_file); + self + } + pub fn with_tokenizer_config_file(mut self, tokenizer_config_file: &str) -> Self { + self.$field = self + .$field + .with_tokenizer_config_file(tokenizer_config_file); + self + } + pub fn with_generation_config_file(mut self, generation_config_file: &str) -> Self { + self.$field = self + .$field + .with_generation_config_file(generation_config_file); + self + } + pub fn with_vocab_file(mut self, vocab_file: &str) -> Self { + self.$field = self.$field.with_vocab_file(vocab_file); + self + } + pub fn with_vocab_txt(mut self, vocab_txt: &str) -> Self { + self.$field = self.$field.with_vocab_txt(vocab_txt); + self + } + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.$field = self.$field.with_temperature(temperature); + self + } + pub fn with_topp(mut self, topp: f32) -> Self { + self.$field = self.$field.with_topp(topp); + self + } + } + }; +} diff --git a/src/viz/annotator.rs b/src/viz/annotator.rs index c51b35e..6b8b045 100644 --- a/src/viz/annotator.rs +++ b/src/viz/annotator.rs @@ -4,7 +4,7 @@ use anyhow::Result; use crate::{DrawContext, Drawable, Image, Style, TextRenderer}; /// Annotator provides configuration for drawing annotations on images, -/// including styles, color palettes, and text rendering options. 
+/// including styles, color palettes, and text rendering config. #[derive(Clone, Builder)] pub struct Annotator { prob_style: Option