diff --git a/Cargo.toml b/Cargo.toml
index 1cd2b59..b4146ae 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "usls"
edition = "2021"
-version = "0.1.0-beta.1"
+version = "0.1.0-beta.2"
rust-version = "1.82"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
@@ -44,6 +44,7 @@ ort = { version = "2.0.0-rc.9", default-features = false, optional = true , feat
"half"
]}
tokenizers = { version = "0.21.1" }
+paste = "1.0.15"
[build-dependencies]
prost-build = "0.13.5"
diff --git a/README.md b/README.md
index f03afd1..cb658fb 100644
--- a/README.md
+++ b/README.md
@@ -116,7 +116,8 @@
| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | |
| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | ✅ | ✅ | ✅ | | |
-| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation Answering | [demo](examples/rmbg) | ✅ | ✅ | ✅ | | |
+| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation<br />Background Erase | [demo](examples/rmbg) | ✅ | ✅ | ✅ | | |
+| [BEN2](https://huggingface.co/PramaLLC/BEN2) | Image Segmentation<br />Background Erase | [demo](examples/ben2) | ✅ | ✅ | ✅ | | |
diff --git a/examples/ben2/main.rs b/examples/ben2/main.rs
index 9096318..5d0375c 100644
--- a/examples/ben2/main.rs
+++ b/examples/ben2/main.rs
@@ -1,4 +1,4 @@
-use usls::{models::RMBG, Annotator, DataLoader, Options};
+use usls::{models::RMBG, Annotator, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
@@ -20,11 +20,11 @@ fn main() -> anyhow::Result<()> {
let args: Args = argh::from_env();
// build model
- let options = Options::ben2_base()
+ let config = ModelConfig::ben2_base()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
- let mut model = RMBG::new(options)?;
+ let mut model = RMBG::new(config)?;
// load image
let xs = DataLoader::try_read_n(&["./assets/cat.png"])?;
diff --git a/examples/blip/main.rs b/examples/blip/main.rs
index a04bfb8..6367a67 100644
--- a/examples/blip/main.rs
+++ b/examples/blip/main.rs
@@ -1,4 +1,4 @@
-use usls::{models::Blip, DataLoader, Options};
+use usls::{models::Blip, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// BLIP Example
@@ -20,13 +20,10 @@ fn main() -> anyhow::Result<()> {
let args: Args = argh::from_env();
// build model
- let options_visual = Options::blip_v1_base_caption_visual()
- .with_model_device(args.device.as_str().try_into()?)
+ let config = ModelConfig::blip_v1_base_caption()
+ .with_device_all(args.device.as_str().try_into()?)
.commit()?;
- let options_textual = Options::blip_v1_base_caption_textual()
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?;
- let mut model = Blip::new(options_visual, options_textual)?;
+ let mut model = Blip::new(config)?;
// image caption
let xs = DataLoader::try_read_n(&args.source)?;
diff --git a/examples/classifier/main.rs b/examples/classifier/main.rs
index fe9536e..73245ff 100644
--- a/examples/classifier/main.rs
+++ b/examples/classifier/main.rs
@@ -1,4 +1,4 @@
-use usls::{models::ImageClassifier, Annotator, DataLoader, Options};
+use usls::{models::ImageClassifier, Annotator, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
@@ -36,20 +36,20 @@ fn main() -> anyhow::Result<()> {
let args: Args = argh::from_env();
// build model
- let options = match args.model.to_lowercase().as_str() {
- "beit" => Options::beit_base(),
- "convnext" => Options::convnext_v2_atto(),
- "deit" => Options::deit_tiny_distill(),
- "fastvit" => Options::fastvit_t8_distill(),
- "mobileone" => Options::mobileone_s0(),
+ let config = match args.model.to_lowercase().as_str() {
+ "beit" => ModelConfig::beit_base(),
+ "convnext" => ModelConfig::convnext_v2_atto(),
+ "deit" => ModelConfig::deit_tiny_distill(),
+ "fastvit" => ModelConfig::fastvit_t8_distill(),
+ "mobileone" => ModelConfig::mobileone_s0(),
_ => anyhow::bail!("Unsupported model: {}", args.model),
};
- let options = options
+ let config = config
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
- let mut model = ImageClassifier::try_from(options)?;
+ let mut model = ImageClassifier::try_from(config)?;
// load images
let xs = DataLoader::try_read_n(&args.source)?;
diff --git a/examples/clip/main.rs b/examples/clip/main.rs
index a600913..c650d00 100644
--- a/examples/clip/main.rs
+++ b/examples/clip/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::Clip, DataLoader, Ops, Options};
+use usls::{models::Clip, DataLoader, ModelConfig, Ops};
#[derive(argh::FromArgs)]
/// CLIP Example
@@ -14,18 +14,13 @@ fn main() -> Result<()> {
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
-
let args: Args = argh::from_env();
+
// build model
- let options_visual = Options::jina_clip_v1_visual()
- // clip_vit_b32_visual()
- .with_model_device(args.device.as_str().try_into()?)
+ let config = ModelConfig::jina_clip_v1()
+ .with_device_all(args.device.as_str().try_into()?)
.commit()?;
- let options_textual = Options::jina_clip_v1_textual()
- // clip_vit_b32_textual()
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?;
- let mut model = Clip::new(options_visual, options_textual)?;
+ let mut model = Clip::new(config)?;
// texts
let texts = vec![
diff --git a/examples/d-fine/main.rs b/examples/d-fine/main.rs
index b74e3dc..ffdfaca 100644
--- a/examples/d-fine/main.rs
+++ b/examples/d-fine/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::RTDETR, Annotator, DataLoader, Options};
+use usls::{models::RTDETR, Annotator, DataLoader, ModelConfig};
fn main() -> Result<()> {
tracing_subscriber::fmt()
@@ -7,9 +7,8 @@ fn main() -> Result<()> {
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
- // options
- let options = Options::d_fine_n_coco().commit()?;
- let mut model = RTDETR::new(options)?;
+ // config
+ let mut model = RTDETR::new(ModelConfig::d_fine_n_coco().commit()?)?;
// load
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
diff --git a/examples/db/main.rs b/examples/db/main.rs
index e205c35..2f92e65 100644
--- a/examples/db/main.rs
+++ b/examples/db/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::DB, Annotator, DataLoader, Options, Style};
+use usls::{models::DB, Annotator, DataLoader, ModelConfig, Style};
#[derive(argh::FromArgs)]
/// Example
@@ -41,15 +41,13 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let options = match &args.model {
- Some(m) => Options::db().with_model_file(m),
- None => Options::ppocr_det_v4_ch().with_model_dtype(args.dtype.as_str().try_into()?),
- };
- let mut model = DB::new(
- options
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- )?;
+ let config = match &args.model {
+ Some(m) => ModelConfig::db().with_model_file(m),
+ None => ModelConfig::ppocr_det_v4_ch().with_model_dtype(args.dtype.as_str().try_into()?),
+ }
+ .with_device_all(args.device.as_str().try_into()?)
+ .commit()?;
+ let mut model = DB::new(config)?;
// load image
let xs = DataLoader::try_read_n(&[
diff --git a/examples/deim/main.rs b/examples/deim/main.rs
index b253394..6fc358f 100644
--- a/examples/deim/main.rs
+++ b/examples/deim/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::RTDETR, Annotator, DataLoader, Options};
+use usls::{models::RTDETR, Annotator, DataLoader, ModelConfig};
fn main() -> Result<()> {
tracing_subscriber::fmt()
@@ -7,9 +7,8 @@ fn main() -> Result<()> {
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
- // options
- let options = Options::deim_dfine_s_coco().commit()?;
- let mut model = RTDETR::new(options)?;
+ // config
+ let mut model = RTDETR::new(ModelConfig::deim_dfine_s_coco().commit()?)?;
// load
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
diff --git a/examples/depth-anything/main.rs b/examples/depth-anything/main.rs
index f5b1d3e..3981ce2 100644
--- a/examples/depth-anything/main.rs
+++ b/examples/depth-anything/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::DepthAnything, Annotator, DataLoader, Options, Style};
+use usls::{models::DepthAnything, Annotator, DataLoader, ModelConfig, Style};
fn main() -> Result<()> {
tracing_subscriber::fmt()
@@ -8,8 +8,7 @@ fn main() -> Result<()> {
.init();
// build model
- let options = Options::depth_anything_v2_small().commit()?;
- let mut model = DepthAnything::new(options)?;
+ let mut model = DepthAnything::new(ModelConfig::depth_anything_v2_small().commit()?)?;
// load
let xs = DataLoader::try_read_n(&["images/street.jpg"])?;
diff --git a/examples/depth-pro/main.rs b/examples/depth-pro/main.rs
index 929fa53..8c0f559 100644
--- a/examples/depth-pro/main.rs
+++ b/examples/depth-pro/main.rs
@@ -1,6 +1,6 @@
use anyhow::Result;
use usls::DataLoader;
-use usls::{models::DepthPro, Annotator, Options, Style};
+use usls::{models::DepthPro, Annotator, ModelConfig, Style};
#[derive(argh::FromArgs)]
/// Example
@@ -23,11 +23,11 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// model
- let options = Options::depth_pro()
+ let config = ModelConfig::depth_pro()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
- let mut model = DepthPro::new(options)?;
+ let mut model = DepthPro::new(config)?;
// load
let xs = DataLoader::try_read_n(&["images/street.jpg"])?;
diff --git a/examples/dinov2/main.rs b/examples/dinov2/main.rs
index e39427b..f9eda41 100644
--- a/examples/dinov2/main.rs
+++ b/examples/dinov2/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::DINOv2, DataLoader, Options};
+use usls::{models::DINOv2, DataLoader, ModelConfig};
fn main() -> Result<()> {
tracing_subscriber::fmt()
@@ -11,8 +11,10 @@ fn main() -> Result<()> {
let xs = DataLoader::try_read_n(&["./assets/bus.jpg", "./assets/bus.jpg"])?;
// model
- let options = Options::dinov2_small().with_batch_size(xs.len()).commit()?;
- let mut model = DINOv2::new(options)?;
+ let config = ModelConfig::dinov2_small()
+ .with_batch_size_all(xs.len())
+ .commit()?;
+ let mut model = DINOv2::new(config)?;
// encode images
let y = model.encode_images(&xs)?;
diff --git a/examples/doclayout-yolo/main.rs b/examples/doclayout-yolo/main.rs
index b6ab0d8..3b2c467 100644
--- a/examples/doclayout-yolo/main.rs
+++ b/examples/doclayout-yolo/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::YOLO, Annotator, DataLoader, Options};
+use usls::{models::YOLO, Annotator, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
@@ -18,7 +18,7 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let config = Options::doclayout_yolo_docstructbench()
+ let config = ModelConfig::doclayout_yolo_docstructbench()
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let mut model = YOLO::new(config)?;
diff --git a/examples/fast/main.rs b/examples/fast/main.rs
index 9304c9d..493f737 100644
--- a/examples/fast/main.rs
+++ b/examples/fast/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::DB, Annotator, DataLoader, Options, Scale, Style};
+use usls::{models::DB, Annotator, DataLoader, ModelConfig, Scale, Style};
#[derive(argh::FromArgs)]
/// Example
@@ -26,16 +26,16 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let options = match args.scale.as_str().try_into()? {
- Scale::T => Options::fast_tiny(),
- Scale::S => Options::fast_small(),
- Scale::B => Options::fast_base(),
+ let config = match args.scale.as_str().try_into()? {
+ Scale::T => ModelConfig::fast_tiny(),
+ Scale::S => ModelConfig::fast_small(),
+ Scale::B => ModelConfig::fast_base(),
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale),
};
let mut model = DB::new(
- options
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
+ config
+ .with_dtype_all(args.dtype.as_str().try_into()?)
+ .with_device_all(args.device.as_str().try_into()?)
.commit()?,
)?;
diff --git a/examples/fastsam/main.rs b/examples/fastsam/main.rs
index 29e2c1e..47a54a2 100644
--- a/examples/fastsam/main.rs
+++ b/examples/fastsam/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::YOLO, Annotator, DataLoader, Options};
+use usls::{models::YOLO, Annotator, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
@@ -22,7 +22,7 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let config = Options::fastsam_s()
+ let config = ModelConfig::fastsam_s()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
diff --git a/examples/florence2/README.md b/examples/florence2/README.md
index 6078515..377cb5d 100644
--- a/examples/florence2/README.md
+++ b/examples/florence2/README.md
@@ -1,7 +1,7 @@
## Quick Start
```shell
-cargo run -r -F cuda --example florence2 -- --device cuda --scale base --dtype fp16
+cargo run -r -F cuda --example florence2 -- --device cuda --dtype fp16
```
diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs
index 4d6bde0..63bf33c 100644
--- a/examples/florence2/main.rs
+++ b/examples/florence2/main.rs
@@ -1,20 +1,16 @@
use anyhow::Result;
-use usls::{models::Florence2, Annotator, DataLoader, Options, Scale, Style, Task};
+use usls::{models::Florence2, Annotator, DataLoader, ModelConfig, Style, Task};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// dtype
- #[argh(option, default = "String::from(\"auto\")")]
+ #[argh(option, default = "String::from(\"fp16\")")]
dtype: String,
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
-
- /// scale
- #[argh(option, default = "String::from(\"base\")")]
- scale: String,
}
fn main() -> Result<()> {
@@ -29,51 +25,12 @@ fn main() -> Result<()> {
let xs = DataLoader::try_read_n(&["images/green-car.jpg", "assets/bus.jpg"])?;
// build model
- let (
- options_vision_encoder,
- options_text_embed,
- options_encoder,
- options_decoder,
- options_decoder_merged,
- ) = match args.scale.as_str().try_into()? {
- Scale::B => (
- Options::florence2_visual_encoder_base(),
- Options::florence2_textual_embed_base(),
- Options::florence2_texual_encoder_base(),
- Options::florence2_texual_decoder_base(),
- Options::florence2_texual_decoder_merged_base(),
- ),
- Scale::L => todo!(),
- _ => anyhow::bail!("Unsupported Florence2 scale."),
- };
-
- let mut model = Florence2::new(
- options_vision_encoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .with_batch_size(xs.len())
- .commit()?,
- options_text_embed
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .with_batch_size(xs.len())
- .commit()?,
- options_encoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .with_batch_size(xs.len())
- .commit()?,
- options_decoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .with_batch_size(xs.len())
- .commit()?,
- options_decoder_merged
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .with_batch_size(xs.len())
- .commit()?,
- )?;
+ let config = ModelConfig::florence2_base()
+ .with_dtype_all(args.dtype.as_str().try_into()?)
+ .with_device_all(args.device.as_str().try_into()?)
+ .with_batch_size_all(xs.len())
+ .commit()?;
+ let mut model = Florence2::new(config)?;
// tasks
let tasks = [
diff --git a/examples/grounding-dino/main.rs b/examples/grounding-dino/main.rs
index e80083f..e155428 100644
--- a/examples/grounding-dino/main.rs
+++ b/examples/grounding-dino/main.rs
@@ -1,11 +1,11 @@
use anyhow::Result;
-use usls::{models::GroundingDINO, Annotator, DataLoader, Options};
+use usls::{models::GroundingDINO, Annotator, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// dtype
- #[argh(option, default = "String::from(\"auto\")")]
+ #[argh(option, default = "String::from(\"fp16\")")]
dtype: String,
/// device
@@ -45,7 +45,7 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
- let options = Options::grounding_dino_tiny()
+ let config = ModelConfig::grounding_dino_tiny()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
@@ -53,7 +53,7 @@ fn main() -> Result<()> {
.with_text_confs(&[0.25])
.commit()?;
- let mut model = GroundingDINO::new(options)?;
+ let mut model = GroundingDINO::new(config)?;
// load images
let xs = DataLoader::try_read_n(&args.source)?;
diff --git a/examples/linknet/main.rs b/examples/linknet/main.rs
index 3dfd21a..f7c85ad 100644
--- a/examples/linknet/main.rs
+++ b/examples/linknet/main.rs
@@ -1,6 +1,6 @@
use anyhow::Result;
use usls::DataLoader;
-use usls::{models::DB, Annotator, Options, Scale, Style};
+use usls::{models::DB, Annotator, ModelConfig, Scale, Style};
#[derive(argh::FromArgs)]
/// Example
@@ -27,14 +27,14 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let options = match args.scale.as_str().try_into()? {
- Scale::T => Options::linknet_r18(),
- Scale::S => Options::linknet_r34(),
- Scale::B => Options::linknet_r50(),
+ let config = match args.scale.as_str().try_into()? {
+ Scale::T => ModelConfig::linknet_r18(),
+ Scale::S => ModelConfig::linknet_r34(),
+ Scale::B => ModelConfig::linknet_r50(),
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale),
};
let mut model = DB::new(
- options
+ config
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?,
diff --git a/examples/modnet/main.rs b/examples/modnet/main.rs
index b0ee231..9dde076 100644
--- a/examples/modnet/main.rs
+++ b/examples/modnet/main.rs
@@ -1,4 +1,4 @@
-use usls::{models::MODNet, Annotator, DataLoader, Options};
+use usls::{models::MODNet, Annotator, DataLoader, ModelConfig};
fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
@@ -7,8 +7,7 @@ fn main() -> anyhow::Result<()> {
.init();
// build model
- let options = Options::modnet_photographic().commit()?;
- let mut model = MODNet::new(options)?;
+ let mut model = MODNet::new(ModelConfig::modnet_photographic().commit()?)?;
// load image
let xs = DataLoader::try_read_n(&["images/liuyifei.png"])?;
diff --git a/examples/moondream2/main.rs b/examples/moondream2/main.rs
index 963197a..907a9aa 100644
--- a/examples/moondream2/main.rs
+++ b/examples/moondream2/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::Moondream2, Annotator, DataLoader, Options, Scale, Task};
+use usls::{models::Moondream2, Annotator, DataLoader, ModelConfig, Scale, Task};
#[derive(argh::FromArgs)]
/// Example
@@ -39,81 +39,16 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let (
- options_vision_encoder,
- options_vision_projection,
- options_text_decoder,
- options_text_encoder,
- options_coord_decoder,
- options_coord_encoder,
- options_size_decoder,
- options_size_encoder,
- ) = match args.scale.as_str().try_into()? {
- Scale::Billion(2.) => (
- Options::moondream2_2b_vision_encoder(),
- Options::moondream2_2b_vision_projection(),
- Options::moondream2_2b_text_decoder(),
- Options::moondream2_2b_text_encoder(),
- Options::moondream2_2b_coord_decoder(),
- Options::moondream2_2b_coord_encoder(),
- Options::moondream2_2b_size_decoder(),
- Options::moondream2_2b_size_encoder(),
- ),
- Scale::Billion(0.5) => (
- Options::moondream2_0_5b_vision_encoder(),
- Options::moondream2_0_5b_vision_projection(),
- Options::moondream2_0_5b_text_decoder(),
- Options::moondream2_0_5b_text_encoder(),
- Options::moondream2_0_5b_coord_decoder(),
- Options::moondream2_0_5b_coord_encoder(),
- Options::moondream2_0_5b_size_decoder(),
- Options::moondream2_0_5b_size_encoder(),
- ),
+ let config = match args.scale.as_str().try_into()? {
+ Scale::Billion(0.5) => ModelConfig::moondream2_0_5b(),
+ Scale::Billion(2.) => ModelConfig::moondream2_2b(),
_ => unimplemented!(),
- };
+ }
+ .with_dtype_all(args.dtype.as_str().try_into()?)
+ .with_device_all(args.device.as_str().try_into()?)
+ .commit()?;
- let mut model = Moondream2::new(
- options_vision_encoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- options_vision_projection
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- options_text_encoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- options_text_decoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- Some(
- options_coord_encoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- ),
- Some(
- options_coord_decoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- ),
- Some(
- options_size_encoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- ),
- Some(
- options_size_decoder
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- ),
- )?;
+ let mut model = Moondream2::new(config)?;
// load images
let xs = DataLoader::try_read_n(&args.source)?;
@@ -142,13 +77,6 @@ fn main() -> Result<()> {
}
Task::OpenSetDetection(_) | Task::OpenSetKeypointsDetection(_) => {
println!("{:?}", ys);
- // let annotator = Annotator::default()
- // .with_bboxes_thickness(4)
- // .without_bboxes_conf(true)
- // .with_keypoints_radius(6)
- // .with_keypoints_name(true)
- // .with_saveout("moondream2");
- // annotator.annotate(&xs, &ys);
// annotate
let annotator = Annotator::default()
diff --git a/examples/owlv2/main.rs b/examples/owlv2/main.rs
index f717e00..a0f9a08 100644
--- a/examples/owlv2/main.rs
+++ b/examples/owlv2/main.rs
@@ -1,6 +1,6 @@
use anyhow::Result;
use usls::DataLoader;
-use usls::{models::OWLv2, Annotator, Options};
+use usls::{models::OWLv2, Annotator, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
@@ -46,14 +46,14 @@ fn main() -> Result<()> {
.init();
let args: Args = argh::from_env();
- // options
- let options = Options::owlv2_base_ensemble()
+ // config
+ let config = ModelConfig::owlv2_base_ensemble()
// owlv2_base()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_class_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
.commit()?;
- let mut model = OWLv2::new(options)?;
+ let mut model = OWLv2::new(config)?;
// load
let xs = DataLoader::try_read_n(&args.source)?;
diff --git a/examples/picodet-layout/main.rs b/examples/picodet-layout/main.rs
index baa9bf4..5034a54 100644
--- a/examples/picodet-layout/main.rs
+++ b/examples/picodet-layout/main.rs
@@ -1,6 +1,6 @@
use anyhow::Result;
use usls::DataLoader;
-use usls::{models::PicoDet, Annotator, Options};
+use usls::{models::PicoDet, Annotator, ModelConfig};
fn main() -> Result<()> {
tracing_subscriber::fmt()
@@ -8,12 +8,11 @@ fn main() -> Result<()> {
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
- // options
- let options = Options::picodet_layout_1x()
- // picodet_l_layout_3cls()
- // picodet_l_layout_17cls()
- .commit()?;
- let mut model = PicoDet::new(options)?;
+ // config
+ let config = ModelConfig::picodet_layout_1x().commit()?;
+ // picodet_l_layout_3cls()
+ // picodet_l_layout_17cls()
+ let mut model = PicoDet::new(config)?;
// load
let xs = DataLoader::try_read_n(&["images/academic.jpg"])?;
diff --git a/examples/rfdetr/main.rs b/examples/rfdetr/main.rs
index e034e45..eb39ef8 100644
--- a/examples/rfdetr/main.rs
+++ b/examples/rfdetr/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::RFDETR, Annotator, DataLoader, Options};
+use usls::{models::RFDETR, Annotator, DataLoader, ModelConfig};
fn main() -> Result<()> {
tracing_subscriber::fmt()
@@ -7,9 +7,8 @@ fn main() -> Result<()> {
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
- // options
- let options = Options::rfdetr_base().commit()?;
- let mut model = RFDETR::new(options)?;
+ // config
+ let mut model = RFDETR::new(ModelConfig::rfdetr_base().commit()?)?;
// load
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
diff --git a/examples/rmbg/main.rs b/examples/rmbg/main.rs
index cb34be0..4fbdfbd 100644
--- a/examples/rmbg/main.rs
+++ b/examples/rmbg/main.rs
@@ -1,10 +1,10 @@
-use usls::{models::RMBG, Annotator, DataLoader, Options};
+use usls::{models::RMBG, Annotator, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// dtype
- #[argh(option, default = "String::from(\"auto\")")]
+ #[argh(option, default = "String::from(\"fp16\")")]
dtype: String,
/// device
@@ -23,18 +23,18 @@ fn main() -> anyhow::Result<()> {
.init();
let args: Args = argh::from_env();
- let options = match args.ver {
- 1.4 => Options::rmbg1_4(),
- 2.0 => Options::rmbg2_0(),
+ let config = match args.ver {
+ 1.4 => ModelConfig::rmbg1_4(),
+ 2.0 => ModelConfig::rmbg2_0(),
_ => unreachable!("Unsupported version"),
};
// build model
- let options = options
+ let config = config
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
- let mut model = RMBG::new(options)?;
+ let mut model = RMBG::new(config)?;
// load image
let xs = DataLoader::try_read_n(&["./assets/cat.png"])?;
diff --git a/examples/rtdetr/main.rs b/examples/rtdetr/main.rs
index 1d3aa5e..8a19510 100644
--- a/examples/rtdetr/main.rs
+++ b/examples/rtdetr/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::RTDETR, Annotator, DataLoader, Options};
+use usls::{models::RTDETR, Annotator, DataLoader, ModelConfig};
fn main() -> Result<()> {
tracing_subscriber::fmt()
@@ -7,15 +7,14 @@ fn main() -> Result<()> {
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
- // options
- let options = Options::rtdetr_v2_s_coco()
- // rtdetr_v1_r18vd_coco()
- // rtdetr_v2_ms_coco()
- // rtdetr_v2_m_coco()
- // rtdetr_v2_l_coco()
- // rtdetr_v2_x_coco()
- .commit()?;
- let mut model = RTDETR::new(options)?;
+ // config
+ let config = ModelConfig::rtdetr_v2_s_coco().commit()?;
+ // rtdetr_v1_r18vd_coco()
+ // rtdetr_v2_ms_coco()
+ // rtdetr_v2_m_coco()
+ // rtdetr_v2_l_coco()
+ // rtdetr_v2_x_coco()
+ let mut model = RTDETR::new(config)?;
// load
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs
index 314e6fc..c2dde49 100644
--- a/examples/rtmo/main.rs
+++ b/examples/rtmo/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::RTMO, Annotator, DataLoader, Options, Style, SKELETON_COCO_19};
+use usls::{models::RTMO, Annotator, DataLoader, ModelConfig, Style, SKELETON_COCO_19};
fn main() -> Result<()> {
tracing_subscriber::fmt()
@@ -8,7 +8,7 @@ fn main() -> Result<()> {
.init();
// build model
- let mut model = RTMO::new(Options::rtmo_s().commit()?)?;
+ let mut model = RTMO::new(ModelConfig::rtmo_s().commit()?)?;
// load image
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
diff --git a/examples/sam/main.rs b/examples/sam/main.rs
index 21e4fcc..0da9d7f 100644
--- a/examples/sam/main.rs
+++ b/examples/sam/main.rs
@@ -1,7 +1,7 @@
use anyhow::Result;
use usls::{
models::{SamKind, SamPrompt, SAM},
- Annotator, DataLoader, Options, Scale,
+ Annotator, DataLoader, ModelConfig, Scale,
};
#[derive(argh::FromArgs)]
@@ -28,40 +28,22 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// Build model
- let (options_encoder, options_decoder) = match args.kind.as_str().try_into()? {
- SamKind::Sam => (
- Options::sam_v1_base_encoder(),
- Options::sam_v1_base_decoder(),
- ),
+ let config = match args.kind.as_str().try_into()? {
+ SamKind::Sam => ModelConfig::sam_v1_base(),
SamKind::Sam2 => match args.scale.as_str().try_into()? {
- Scale::T => (Options::sam2_tiny_encoder(), Options::sam2_tiny_decoder()),
- Scale::S => (Options::sam2_small_encoder(), Options::sam2_small_decoder()),
- Scale::B => (
- Options::sam2_base_plus_encoder(),
- Options::sam2_base_plus_decoder(),
- ),
+ Scale::T => ModelConfig::sam2_tiny(),
+ Scale::S => ModelConfig::sam2_small(),
+ Scale::B => ModelConfig::sam2_base_plus(),
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t.", args.scale),
},
+ SamKind::MobileSam => ModelConfig::mobile_sam_tiny(),
+ SamKind::SamHq => ModelConfig::sam_hq_tiny(),
+ SamKind::EdgeSam => ModelConfig::edge_sam_3x(),
+ }
+ .with_device_all(args.device.as_str().try_into()?)
+ .commit()?;
- SamKind::MobileSam => (
- Options::mobile_sam_tiny_encoder(),
- Options::mobile_sam_tiny_decoder(),
- ),
- SamKind::SamHq => (
- Options::sam_hq_tiny_encoder(),
- Options::sam_hq_tiny_decoder(),
- ),
- SamKind::EdgeSam => (
- Options::edge_sam_3x_encoder(),
- Options::edge_sam_3x_decoder(),
- ),
- };
-
- let options_encoder = options_encoder
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?;
- let options_decoder = options_decoder.commit()?;
- let mut model = SAM::new(options_encoder, options_decoder)?;
+ let mut model = SAM::new(config)?;
// Load image
let xs = DataLoader::try_read_n(&["images/truck.jpg"])?;
diff --git a/examples/sam2/main.rs b/examples/sam2/main.rs
index e8722a4..79469f9 100644
--- a/examples/sam2/main.rs
+++ b/examples/sam2/main.rs
@@ -1,7 +1,7 @@
use anyhow::Result;
use usls::{
models::{SamPrompt, SAM2},
- Annotator, DataLoader, Options, Scale,
+ Annotator, DataLoader, ModelConfig, Scale,
};
#[derive(argh::FromArgs)]
@@ -25,33 +25,16 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// Build model
- let (options_encoder, options_decoder) = match args.scale.as_str().try_into()? {
- Scale::T => (
- Options::sam2_1_tiny_encoder(),
- Options::sam2_1_tiny_decoder(),
- ),
- Scale::S => (
- Options::sam2_1_small_encoder(),
- Options::sam2_1_small_decoder(),
- ),
- Scale::B => (
- Options::sam2_1_base_plus_encoder(),
- Options::sam2_1_base_plus_decoder(),
- ),
- Scale::L => (
- Options::sam2_1_large_encoder(),
- Options::sam2_1_large_decoder(),
- ),
+ let config = match args.scale.as_str().try_into()? {
+ Scale::T => ModelConfig::sam2_1_tiny(),
+ Scale::S => ModelConfig::sam2_1_small(),
+ Scale::B => ModelConfig::sam2_1_base_plus(),
+ Scale::L => ModelConfig::sam2_1_large(),
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale),
- };
-
- let options_encoder = options_encoder
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?;
- let options_decoder = options_decoder
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?;
- let mut model = SAM2::new(options_encoder, options_decoder)?;
+ }
+ .with_device_all(args.device.as_str().try_into()?)
+ .commit()?;
+ let mut model = SAM2::new(config)?;
// Load image
let xs = DataLoader::try_read_n(&["images/truck.jpg"])?;
diff --git a/examples/sapiens/main.rs b/examples/sapiens/main.rs
index a209e27..dbc2c23 100644
--- a/examples/sapiens/main.rs
+++ b/examples/sapiens/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::Sapiens, Annotator, DataLoader, Options};
+use usls::{models::Sapiens, Annotator, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
@@ -17,10 +17,10 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build
- let options = Options::sapiens_seg_0_3b()
+ let config = ModelConfig::sapiens_seg_0_3b()
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
- let mut model = Sapiens::new(options)?;
+ let mut model = Sapiens::new(config)?;
// load
let xs = DataLoader::try_read_n(&["images/paul-george.jpg"])?;
diff --git a/examples/slanet/main.rs b/examples/slanet/main.rs
index 362671c..fe66322 100644
--- a/examples/slanet/main.rs
+++ b/examples/slanet/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::SLANet, Annotator, Color, DataLoader, Options};
+use usls::{models::SLANet, Annotator, Color, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
@@ -26,11 +26,11 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let options = Options::slanet_lcnet_v2_mobile_ch()
+ let config = ModelConfig::slanet_lcnet_v2_mobile_ch()
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.as_str().try_into()?)
.commit()?;
- let mut model = SLANet::new(options)?;
+ let mut model = SLANet::new(config)?;
// load
let xs = DataLoader::try_read_n(&[args.source])?;
diff --git a/examples/smolvlm/main.rs b/examples/smolvlm/main.rs
index 3bc87fb..7069282 100644
--- a/examples/smolvlm/main.rs
+++ b/examples/smolvlm/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::SmolVLM, DataLoader, Options, Scale};
+use usls::{models::SmolVLM, DataLoader, ModelConfig, Scale};
#[derive(argh::FromArgs)]
/// Example
@@ -29,32 +29,15 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let (options_vision_encoder, options_text_embed, options_decode) =
- match args.scale.as_str().try_into()? {
- Scale::Million(256.) => (
- Options::smolvlm_vision_256m(),
- Options::smolvlm_text_embed_256m(),
- Options::smolvlm_decoder_256m(),
- ),
- Scale::Million(500.) => (
- Options::smolvlm_vision_500m(),
- Options::smolvlm_text_embed_500m(),
- Options::smolvlm_decoder_500m(),
- ),
- _ => unimplemented!(),
- };
+ let config = match args.scale.as_str().try_into()? {
+ Scale::Million(256.) => ModelConfig::smolvlm_256m(),
+ Scale::Million(500.) => ModelConfig::smolvlm_500m(),
+ _ => unimplemented!(),
+ }
+ .with_device_all(args.device.as_str().try_into()?)
+ .commit()?;
- let mut model = SmolVLM::new(
- options_vision_encoder
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- options_text_embed
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- options_decode
- .with_model_device(args.device.as_str().try_into()?)
- .commit()?,
- )?;
+ let mut model = SmolVLM::new(config)?;
// load images
let xs = DataLoader::try_read_n(&args.source)?;
diff --git a/examples/svtr/main.rs b/examples/svtr/main.rs
index bd92036..f39c464 100644
--- a/examples/svtr/main.rs
+++ b/examples/svtr/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::SVTR, DataLoader, Options};
+use usls::{models::SVTR, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
@@ -22,13 +22,13 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let options = Options::ppocr_rec_v4_ch()
+ let config = ModelConfig::ppocr_rec_v4_ch()
// ppocr_rec_v4_en()
// repsvtr_ch()
.with_model_device(args.device.as_str().try_into()?)
.with_model_dtype(args.dtype.as_str().try_into()?)
.commit()?;
- let mut model = SVTR::new(options)?;
+ let mut model = SVTR::new(config)?;
// load images
let dl = DataLoader::new("./examples/svtr/images")?
diff --git a/examples/trocr/main.rs b/examples/trocr/main.rs
index f64b72e..79aad83 100644
--- a/examples/trocr/main.rs
+++ b/examples/trocr/main.rs
@@ -1,6 +1,6 @@
use usls::{
models::{TrOCR, TrOCRKind},
- DataLoader, Options, Scale,
+ DataLoader, ModelConfig, Scale,
};
#[derive(argh::FromArgs)]
@@ -38,52 +38,22 @@ fn main() -> anyhow::Result<()> {
])?;
// build model
- let (options_encoder, options_decoder, options_decoder_merged) =
- match args.scale.as_str().try_into()? {
- Scale::S => match args.kind.as_str().try_into()? {
- TrOCRKind::Printed => (
- Options::trocr_encoder_small_printed(),
- Options::trocr_decoder_small_printed(),
- Options::trocr_decoder_merged_small_printed(),
- ),
- TrOCRKind::HandWritten => (
- Options::trocr_encoder_small_handwritten(),
- Options::trocr_decoder_small_handwritten(),
- Options::trocr_decoder_merged_small_handwritten(),
- ),
- },
- Scale::B => match args.kind.as_str().try_into()? {
- TrOCRKind::Printed => (
- Options::trocr_encoder_base_printed(),
- Options::trocr_decoder_base_printed(),
- Options::trocr_decoder_merged_base_printed(),
- ),
- TrOCRKind::HandWritten => (
- Options::trocr_encoder_base_handwritten(),
- Options::trocr_decoder_base_handwritten(),
- Options::trocr_decoder_merged_base_handwritten(),
- ),
- },
- x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x),
- };
+ let config = match args.scale.as_str().try_into()? {
+ Scale::S => match args.kind.as_str().try_into()? {
+ TrOCRKind::Printed => ModelConfig::trocr_small_printed(),
+ TrOCRKind::HandWritten => ModelConfig::trocr_small_handwritten(),
+ },
+ Scale::B => match args.kind.as_str().try_into()? {
+ TrOCRKind::Printed => ModelConfig::trocr_base_printed(),
+ TrOCRKind::HandWritten => ModelConfig::trocr_base_handwritten(),
+ },
+ x => anyhow::bail!("Unsupported TrOCR scale: {:?}", x),
+ }
+ .with_device_all(args.device.as_str().try_into()?)
+ .with_dtype_all(args.dtype.as_str().try_into()?)
+ .commit()?;
- let mut model = TrOCR::new(
- options_encoder
- .with_model_device(args.device.as_str().try_into()?)
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_batch_size(xs.len())
- .commit()?,
- options_decoder
- .with_model_device(args.device.as_str().try_into()?)
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_batch_size(xs.len())
- .commit()?,
- options_decoder_merged
- .with_model_device(args.device.as_str().try_into()?)
- .with_model_dtype(args.dtype.as_str().try_into()?)
- .with_batch_size(xs.len())
- .commit()?,
- )?;
+ let mut model = TrOCR::new(config)?;
// inference
let ys = model.forward(&xs)?;
diff --git a/examples/yolo-sam2/main.rs b/examples/yolo-sam2/main.rs
index 8b60d87..ed3ca1c 100644
--- a/examples/yolo-sam2/main.rs
+++ b/examples/yolo-sam2/main.rs
@@ -1,7 +1,7 @@
use anyhow::Result;
use usls::{
models::{SamPrompt, SAM2, YOLO},
- Annotator, DataLoader, Options, Scale, Style,
+ Annotator, DataLoader, ModelConfig, Scale, Style,
};
#[derive(argh::FromArgs)]
@@ -21,17 +21,14 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build SAM
- let (options_encoder, options_decoder) = (
- Options::sam2_1_tiny_encoder().commit()?,
- Options::sam2_1_tiny_decoder().commit()?,
- );
- let mut sam = SAM2::new(options_encoder, options_decoder)?;
+ let mut sam = SAM2::new(ModelConfig::sam2_1_tiny().commit()?)?;
// build YOLOv8
- let options_yolo = Options::yolo_detect()
- .with_model_scale(Scale::N)
- .with_model_version(8.into())
+ let options_yolo = ModelConfig::yolo_detect()
+ .with_scale(Scale::N)
+ .with_version(8.into())
.with_model_device(args.device.as_str().try_into()?)
+ .auto_yolo_model_file()
.commit()?;
let mut yolo = YOLO::new(options_yolo)?;
diff --git a/examples/yolo/README.md b/examples/yolo/README.md
index d53104e..d3d214f 100644
--- a/examples/yolo/README.md
+++ b/examples/yolo/README.md
@@ -54,7 +54,7 @@ cargo run -r --example yolo -- --ver 8 --task obb --scale n --image-width 1024 -
cargo run -r --example yolo -- --ver 11 --task obb --scale n --image-width 1024 --image-height 1024 --source images/dota.png # YOLOv11-Obb
```
-**`cargo run -r --example yolo -- --help` for more options**
+**`cargo run -r --example yolo -- --help` for more options**
## Other YOLOv8 Solution Models
diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs
index 86111bd..4f69ad9 100644
--- a/examples/yolo/main.rs
+++ b/examples/yolo/main.rs
@@ -1,7 +1,7 @@
use anyhow::Result;
use usls::{
- models::YOLO, Annotator, DataLoader, Options, Style, NAMES_COCO_80, NAMES_COCO_KEYPOINTS_17,
- NAMES_IMAGENET_1K, SKELETON_COCO_19, SKELETON_COLOR_COCO_19,
+ models::YOLO, Annotator, DataLoader, ModelConfig, Style, NAMES_COCO_80,
+ NAMES_COCO_KEYPOINTS_17, NAMES_IMAGENET_1K, SKELETON_COCO_19, SKELETON_COLOR_COCO_19,
};
#[derive(argh::FromArgs, Debug)]
@@ -132,14 +132,15 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
- let mut options = Options::yolo()
+ let mut config = ModelConfig::yolo()
.with_model_file(&args.model.unwrap_or_default())
- .with_model_task(args.task.as_str().try_into()?)
- .with_model_version(args.ver.try_into()?)
- .with_model_scale(args.scale.as_str().try_into()?)
+ .with_task(args.task.as_str().try_into()?)
+ .with_version(args.ver.try_into()?)
+ .with_scale(args.scale.as_str().try_into()?)
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
- .with_trt_fp16(args.trt_fp16)
+ // .with_trt_fp16(args.trt_fp16)
+ .with_model_trt_fp16(args.trt_fp16)
.with_model_ixx(
0,
0,
@@ -175,27 +176,27 @@ fn main() -> Result<()> {
.exclude_classes(&args.exclude_classes);
if args.use_coco_80_classes {
- options = options.with_class_names(&NAMES_COCO_80);
+ config = config.with_class_names(&NAMES_COCO_80);
}
if args.use_coco_17_keypoints_classes {
- options = options.with_keypoint_names(&NAMES_COCO_KEYPOINTS_17);
+ config = config.with_keypoint_names(&NAMES_COCO_KEYPOINTS_17);
}
if args.use_imagenet_1k_classes {
- options = options.with_class_names(&NAMES_IMAGENET_1K);
+ config = config.with_class_names(&NAMES_IMAGENET_1K);
}
if let Some(nc) = args.num_classes {
- options = options.with_nc(nc);
+ config = config.with_nc(nc);
}
if let Some(nk) = args.num_keypoints {
- options = options.with_nk(nk);
+ config = config.with_nk(nk);
}
if !args.class_names.is_empty() {
- options = options.with_class_names(
+ config = config.with_class_names(
&args
.class_names
.iter()
@@ -205,7 +206,7 @@ fn main() -> Result<()> {
}
if !args.keypoint_names.is_empty() {
- options = options.with_keypoint_names(
+ config = config.with_keypoint_names(
&args
.keypoint_names
.iter()
@@ -215,7 +216,7 @@ fn main() -> Result<()> {
}
// build model
- let mut model = YOLO::try_from(options.commit()?)?;
+ let mut model = YOLO::try_from(config.auto_yolo_model_file().commit()?)?;
// build dataloader
let dl = DataLoader::new(&args.source)?
diff --git a/examples/yoloe/README.md b/examples/yoloe/README.md
index 110c5a4..005ff33 100644
--- a/examples/yoloe/README.md
+++ b/examples/yoloe/README.md
@@ -1,6 +1,6 @@
## Quick Start
```shell
-cargo run -r --example yoloe
+cargo run -r -F cuda --example yoloe -- --device cuda
```
diff --git a/examples/yoloe/main.rs b/examples/yoloe/main.rs
index 841945c..175cbb2 100644
--- a/examples/yoloe/main.rs
+++ b/examples/yoloe/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::YOLO, Annotator, DataLoader, Options, Style};
+use usls::{models::YOLO, Annotator, DataLoader, ModelConfig, Style};
#[derive(argh::FromArgs)]
/// Example
@@ -21,8 +21,8 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
- // options
- let options = Options::yoloe_v8s_seg_pf()
+ // config
+ let config = ModelConfig::yoloe_v8s_seg_pf()
// yoloe_v8m_seg_pf()
// yoloe_v8l_seg_pf()
// yoloe_11s_seg_pf()
@@ -31,7 +31,7 @@ fn main() -> Result<()> {
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
- let mut model = YOLO::new(options)?;
+ let mut model = YOLO::new(config)?;
// load
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
diff --git a/examples/yolop/main.rs b/examples/yolop/main.rs
index 7f062db..767926b 100644
--- a/examples/yolop/main.rs
+++ b/examples/yolop/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::YOLOPv2, Annotator, DataLoader, Options};
+use usls::{models::YOLOPv2, Annotator, DataLoader, ModelConfig};
fn main() -> Result<()> {
tracing_subscriber::fmt()
@@ -8,8 +8,7 @@ fn main() -> Result<()> {
.init();
// build model
- let options = Options::yolop_v2_480x800().commit()?;
- let mut model = YOLOPv2::new(options)?;
+ let mut model = YOLOPv2::new(ModelConfig::yolop_v2_480x800().commit()?)?;
// load image
let xs = DataLoader::try_read_n(&["images/car-view.jpg"])?;
diff --git a/examples/yolov8-rtdetr/main.rs b/examples/yolov8-rtdetr/main.rs
index 3b96e8e..b085f28 100644
--- a/examples/yolov8-rtdetr/main.rs
+++ b/examples/yolov8-rtdetr/main.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use usls::{models::YOLO, Annotator, DataLoader, Options};
+use usls::{models::YOLO, Annotator, DataLoader, ModelConfig};
#[derive(argh::FromArgs)]
/// Example
@@ -22,7 +22,7 @@ fn main() -> Result<()> {
let args: Args = argh::from_env();
// build model
- let config = Options::yolo_v8_rtdetr_l()
+ let config = ModelConfig::yolo_v8_rtdetr_l()
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
diff --git a/src/inference/engine.rs b/src/inference/engine.rs
index ad6cdf0..cd2a9f3 100644
--- a/src/inference/engine.rs
+++ b/src/inference/engine.rs
@@ -13,8 +13,8 @@ use prost::Message;
use std::collections::HashSet;
use crate::{
- build_progress_bar, elapsed, human_bytes_binary, onnx, DType, Device, Iiix, MinOptMax, Ops, Ts,
- Xs, PROGRESS_BAR_STYLE_CYAN_2, PROGRESS_BAR_STYLE_FINISH, X,
+ build_progress_bar, elapsed, human_bytes_binary, onnx, DType, Device, EngineConfig, Iiix,
+ MinOptMax, Ops, Ts, Xs, PROGRESS_BAR_STYLE_CYAN_2, PROGRESS_BAR_STYLE_FINISH, X,
};
impl From for DType {
@@ -93,6 +93,20 @@ impl Default for Engine {
}
impl Engine {
+ pub fn try_from_config(config: &EngineConfig) -> Result<Self> {
+ Self {
+ file: config.file.clone(),
+ spec: config.spec.clone(),
+ iiixs: config.iiixs.clone(),
+ device: config.device,
+ trt_fp16: config.trt_fp16,
+ num_dry_run: config.num_dry_run,
+ graph_opt_level: config.ort_graph_opt_level,
+ ..Default::default()
+ }
+ .build()
+ }
+
pub fn build(mut self) -> Result<Self> {
let name = format!("[{}] ort_initialization", self.spec);
elapsed!(&name, self.ts, {
diff --git a/src/inference/engine_config.rs b/src/inference/engine_config.rs
new file mode 100644
index 0000000..c63f251
--- /dev/null
+++ b/src/inference/engine_config.rs
@@ -0,0 +1,112 @@
+use aksr::Builder;
+use anyhow::Result;
+
+use crate::{try_fetch_file_stem, DType, Device, Hub, Iiix, MinOptMax};
+
+#[derive(Builder, Debug, Clone, Default)]
+pub struct EngineConfig {
+ pub file: String,
+ pub device: Device,
+ pub iiixs: Vec<Iiix>,
+ pub num_dry_run: usize,
+ pub trt_fp16: bool,
+ pub ort_graph_opt_level: Option,
+ pub spec: String, // TODO: move out
+ pub dtype: DType, // For dynamically loading the model
+}
+
+impl EngineConfig {
+ pub fn try_commit(mut self, name: &str) -> Result<Self> {
+ // Identify the local model or fetch the remote model
+ if std::path::PathBuf::from(&self.file).exists() {
+ // Local
+ self.spec = format!("{}/{}", name, try_fetch_file_stem(&self.file)?);
+ } else {
+ if self.file.is_empty() && name.is_empty() {
+ anyhow::bail!(
+ "Failed to commit model. Invalid model config: neither `name` nor `file` were specified. Failed to fetch model from Hub."
+ )
+ }
+
+ // Remote
+ match Hub::is_valid_github_release_url(&self.file) {
+ Some((owner, repo, tag, _file_name)) => {
+ let stem = try_fetch_file_stem(&self.file)?;
+ self.spec = format!("{}/{}-{}-{}-{}", name, owner, repo, tag, stem);
+ self.file = Hub::default().try_fetch(&self.file)?;
+ }
+ None => {
+ // append dtype to model file
+ match self.dtype {
+ d @ (DType::Auto | DType::Fp32) => {
+ if self.file.is_empty() {
+ self.file = format!("{}.onnx", d);
+ }
+ }
+ dtype => {
+ if self.file.is_empty() {
+ self.file = format!("{}.onnx", dtype);
+ } else {
+ let pos = self.file.len() - 5; // .onnx
+ let suffix = self.file.split_off(pos);
+ self.file = format!("{}-{}{}", self.file, dtype, suffix);
+ }
+ }
+ }
+
+ let stem = try_fetch_file_stem(&self.file)?;
+ self.spec = format!("{}/{}", name, stem);
+ self.file = Hub::default().try_fetch(&format!("{}/{}", name, self.file))?;
+ }
+ }
+ }
+
+ Ok(self)
+ }
+}
+
+impl EngineConfig {
+ pub fn with_ixx(mut self, i: usize, ii: usize, x: MinOptMax) -> Self {
+ self.iiixs.push(Iiix::from((i, ii, x)));
+ self
+ }
+
+ pub fn with_batch_size(mut self, x: MinOptMax) -> Self {
+ self.iiixs.push(Iiix::from((0, 0, x)));
+ self
+ }
+}
+
+#[macro_export]
+macro_rules! impl_model_config_methods {
+ ($ty:ty, $field:ident) => {
+ impl $ty {
+ paste::paste! {
+ pub fn [<with_ $field _file>](mut self, file: &str) -> Self {
+ self.$field = self.$field.with_file(file);
+ self
+ }
+ pub fn [<with_ $field _dtype>](mut self, dtype: $crate::DType) -> Self {
+ self.$field = self.$field.with_dtype(dtype);
+ self
+ }
+ pub fn [<with_ $field _device>](mut self, device: $crate::Device) -> Self {
+ self.$field = self.$field.with_device(device);
+ self
+ }
+ pub fn [<with_ $field _trt_fp16>](mut self, x: bool) -> Self {
+ self.$field = self.$field.with_trt_fp16(x);
+ self
+ }
+ pub fn [<with_ $field _num_dry_run>](mut self, x: usize) -> Self {
+ self.$field = self.$field.with_num_dry_run(x);
+ self
+ }
+ pub fn [<with_ $field _ixx>](mut self, i: usize, ii: usize, x: $crate::MinOptMax) -> Self {
+ self.$field = self.$field.with_ixx(i, ii, x);
+ self
+ }
+ }
+ }
+ };
+}
diff --git a/src/inference/image.rs b/src/inference/image.rs
index 50c4778..23b113f 100644
--- a/src/inference/image.rs
+++ b/src/inference/image.rs
@@ -308,12 +308,12 @@ impl Image {
));
}
- let (mut resizer, options) = build_resizer_filter(filter)?;
+ let (mut resizer, config) = build_resizer_filter(filter)?;
let x: DynamicImage = self.to_dyn();
if let ResizeMode::FitExact = mode {
let mut dst = FImage::new(tw, th, PixelType::U8x3);
- resizer.resize(&x, &mut dst, &options)?;
+ resizer.resize(&x, &mut dst, &config)?;
trans_info = trans_info
.with_height_scale(th as f32 / h0 as f32)
.with_width_scale(tw as f32 / w0 as f32);
@@ -362,7 +362,7 @@ impl Image {
};
let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?;
- resizer.resize(&x, &mut dst_cropped, &options)?;
+ resizer.resize(&x, &mut dst_cropped, &config)?;
Ok((Self::from_u8s(&dst.into_vec(), tw, th)?, trans_info))
}
diff --git a/src/inference/mod.rs b/src/inference/mod.rs
index 1cb35bd..8e5783c 100644
--- a/src/inference/mod.rs
+++ b/src/inference/mod.rs
@@ -1,10 +1,12 @@
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
mod engine;
+mod engine_config;
mod hbb;
mod image;
mod instance_meta;
mod keypoint;
mod mask;
+mod model_config;
mod obb;
mod polygon;
mod prob;
@@ -20,11 +22,13 @@ pub(crate) mod onnx {
#[cfg(any(feature = "ort-download-binaries", feature = "ort-load-dynamic"))]
pub use engine::*;
+pub use engine_config::EngineConfig;
pub use hbb::*;
pub use image::*;
pub use instance_meta::*;
pub use keypoint::*;
pub use mask::*;
+pub use model_config::*;
pub use obb::*;
pub use polygon::*;
pub use prob::*;
diff --git a/src/inference/model_config.rs b/src/inference/model_config.rs
new file mode 100644
index 0000000..6870cfe
--- /dev/null
+++ b/src/inference/model_config.rs
@@ -0,0 +1,243 @@
+use aksr::Builder;
+
+use crate::{
+ impl_model_config_methods, impl_process_config_methods,
+ models::{SamKind, YOLOPredsFormat},
+ EngineConfig, ProcessorConfig, Scale, Task, Version,
+};
+
+/// ModelConfig for building models and inference
+#[derive(Builder, Debug, Clone)]
+pub struct ModelConfig {
+ // Basics
+ pub name: &'static str,
+ pub version: Option<Version>,
+ pub task: Option<Task>,
+ pub scale: Option<Scale>,
+
+ // Engines
+ pub model: EngineConfig,
+ pub visual: EngineConfig,
+ pub textual: EngineConfig,
+ pub encoder: EngineConfig,
+ pub decoder: EngineConfig,
+ pub visual_encoder: EngineConfig,
+ pub textual_encoder: EngineConfig,
+ pub visual_decoder: EngineConfig,
+ pub textual_decoder: EngineConfig,
+ pub textual_decoder_merged: EngineConfig,
+ pub size_encoder: EngineConfig,
+ pub size_decoder: EngineConfig,
+ pub coord_encoder: EngineConfig,
+ pub coord_decoder: EngineConfig,
+ pub visual_projection: EngineConfig,
+ pub textual_projection: EngineConfig,
+
+ // Processor
+ pub processor: ProcessorConfig,
+
+ // Others
+ pub class_names: Option>, // TODO: remove Option
+ pub keypoint_names: Option>, // TODO: remove Option
+ pub text_names: Option>, // TODO: remove Option
+ pub class_confs: Vec<f32>,
+ pub keypoint_confs: Vec<f32>,
+ pub text_confs: Vec<f32>,
+ pub apply_softmax: Option<bool>,
+ pub topk: Option<usize>,
+ #[args(aka = "nc")]
+ pub num_classes: Option<usize>,
+ #[args(aka = "nk")]
+ pub num_keypoints: Option<usize>,
+ #[args(aka = "nm")]
+ pub num_masks: Option<usize>,
+ pub iou: Option<f32>,
+ pub apply_nms: Option<bool>,
+ pub find_contours: bool,
+ pub yolo_preds_format: Option<YOLOPredsFormat>,
+ pub classes_excluded: Vec<usize>,
+ pub classes_retained: Vec<usize>,
+ pub min_width: Option<f32>,
+ pub min_height: Option<f32>,
+ pub db_unclip_ratio: Option<f32>,
+ pub db_binary_thresh: Option<f32>,
+ pub sam_kind: Option<SamKind>,
+ pub sam_low_res_mask: Option<bool>,
+}
+
+impl Default for ModelConfig {
+ fn default() -> Self {
+ Self {
+ class_names: None,
+ keypoint_names: None,
+ text_names: None,
+ class_confs: vec![0.25f32],
+ keypoint_confs: vec![0.3f32],
+ text_confs: vec![0.25f32],
+ apply_softmax: Some(false),
+ num_classes: None,
+ num_keypoints: None,
+ num_masks: None,
+ iou: None,
+ find_contours: false,
+ yolo_preds_format: None,
+ classes_excluded: vec![],
+ classes_retained: vec![],
+ apply_nms: None,
+ min_width: None,
+ min_height: None,
+ db_unclip_ratio: Some(1.5),
+ db_binary_thresh: Some(0.2),
+ sam_kind: None,
+ sam_low_res_mask: None,
+ topk: None,
+ model: Default::default(),
+ encoder: Default::default(),
+ decoder: Default::default(),
+ visual: Default::default(),
+ textual: Default::default(),
+ visual_encoder: Default::default(),
+ textual_encoder: Default::default(),
+ visual_decoder: Default::default(),
+ textual_decoder: Default::default(),
+ textual_decoder_merged: Default::default(),
+ processor: ProcessorConfig::default(),
+ size_encoder: Default::default(),
+ size_decoder: Default::default(),
+ coord_encoder: Default::default(),
+ coord_decoder: Default::default(),
+ visual_projection: Default::default(),
+ textual_projection: Default::default(),
+ version: None,
+ task: None,
+ scale: None,
+ name: Default::default(),
+ }
+ }
+}
+
+impl ModelConfig {
+ pub fn exclude_classes(mut self, xs: &[usize]) -> Self {
+ self.classes_retained.clear();
+ self.classes_excluded.extend_from_slice(xs);
+ self
+ }
+
+ pub fn retain_classes(mut self, xs: &[usize]) -> Self {
+ self.classes_excluded.clear();
+ self.classes_retained.extend_from_slice(xs);
+ self
+ }
+
+ pub fn commit(mut self) -> anyhow::Result {
+ fn try_commit(name: &str, mut m: EngineConfig) -> anyhow::Result {
+ if !m.file.is_empty() {
+ m = m.try_commit(name)?;
+ return Ok(m);
+ }
+ Ok(m)
+ }
+
+ self.model = try_commit(self.name, self.model)?;
+ self.visual = try_commit(self.name, self.visual)?;
+ self.textual = try_commit(self.name, self.textual)?;
+ self.encoder = try_commit(self.name, self.encoder)?;
+ self.decoder = try_commit(self.name, self.decoder)?;
+ self.visual_encoder = try_commit(self.name, self.visual_encoder)?;
+ self.textual_encoder = try_commit(self.name, self.textual_encoder)?;
+ self.visual_decoder = try_commit(self.name, self.visual_decoder)?;
+ self.textual_decoder = try_commit(self.name, self.textual_decoder)?;
+ self.textual_decoder_merged = try_commit(self.name, self.textual_decoder_merged)?;
+ self.size_encoder = try_commit(self.name, self.size_encoder)?;
+ self.size_decoder = try_commit(self.name, self.size_decoder)?;
+ self.coord_encoder = try_commit(self.name, self.coord_encoder)?;
+ self.coord_decoder = try_commit(self.name, self.coord_decoder)?;
+ self.visual_projection = try_commit(self.name, self.visual_projection)?;
+ self.textual_projection = try_commit(self.name, self.textual_projection)?;
+
+ Ok(self)
+ }
+
+ pub fn with_batch_size_all(mut self, batch_size: usize) -> Self {
+ self.visual = self.visual.with_ixx(0, 0, batch_size.into());
+ self.textual = self.textual.with_ixx(0, 0, batch_size.into());
+ self.model = self.model.with_ixx(0, 0, batch_size.into());
+ self.encoder = self.encoder.with_ixx(0, 0, batch_size.into());
+ self.decoder = self.decoder.with_ixx(0, 0, batch_size.into());
+ self.visual_encoder = self.visual_encoder.with_ixx(0, 0, batch_size.into());
+ self.textual_encoder = self.textual_encoder.with_ixx(0, 0, batch_size.into());
+ self.visual_decoder = self.visual_decoder.with_ixx(0, 0, batch_size.into());
+ self.textual_decoder = self.textual_decoder.with_ixx(0, 0, batch_size.into());
+ self.textual_decoder_merged = self
+ .textual_decoder_merged
+ .with_ixx(0, 0, batch_size.into());
+ self.size_encoder = self.size_encoder.with_ixx(0, 0, batch_size.into());
+ self.size_decoder = self.size_decoder.with_ixx(0, 0, batch_size.into());
+ self.coord_encoder = self.coord_encoder.with_ixx(0, 0, batch_size.into());
+ self.coord_decoder = self.coord_decoder.with_ixx(0, 0, batch_size.into());
+ self.visual_projection = self.visual_projection.with_ixx(0, 0, batch_size.into());
+ self.textual_projection = self.textual_projection.with_ixx(0, 0, batch_size.into());
+
+ self
+ }
+
+ pub fn with_device_all(mut self, device: crate::Device) -> Self {
+ self.visual = self.visual.with_device(device);
+ self.textual = self.textual.with_device(device);
+ self.model = self.model.with_device(device);
+ self.encoder = self.encoder.with_device(device);
+ self.decoder = self.decoder.with_device(device);
+ self.visual_encoder = self.visual_encoder.with_device(device);
+ self.textual_encoder = self.textual_encoder.with_device(device);
+ self.visual_decoder = self.visual_decoder.with_device(device);
+ self.textual_decoder = self.textual_decoder.with_device(device);
+ self.textual_decoder_merged = self.textual_decoder_merged.with_device(device);
+ self.size_encoder = self.size_encoder.with_device(device);
+ self.size_decoder = self.size_decoder.with_device(device);
+ self.coord_encoder = self.coord_encoder.with_device(device);
+ self.coord_decoder = self.coord_decoder.with_device(device);
+ self.visual_projection = self.visual_projection.with_device(device);
+ self.textual_projection = self.textual_projection.with_device(device);
+
+ self
+ }
+
+ pub fn with_dtype_all(mut self, dtype: crate::DType) -> Self {
+ self.visual = self.visual.with_dtype(dtype);
+ self.textual = self.textual.with_dtype(dtype);
+ self.model = self.model.with_dtype(dtype);
+ self.encoder = self.encoder.with_dtype(dtype);
+ self.decoder = self.decoder.with_dtype(dtype);
+ self.visual_encoder = self.visual_encoder.with_dtype(dtype);
+ self.textual_encoder = self.textual_encoder.with_dtype(dtype);
+ self.visual_decoder = self.visual_decoder.with_dtype(dtype);
+ self.textual_decoder = self.textual_decoder.with_dtype(dtype);
+ self.textual_decoder_merged = self.textual_decoder_merged.with_dtype(dtype);
+ self.size_encoder = self.size_encoder.with_dtype(dtype);
+ self.size_decoder = self.size_decoder.with_dtype(dtype);
+ self.coord_encoder = self.coord_encoder.with_dtype(dtype);
+ self.coord_decoder = self.coord_decoder.with_dtype(dtype);
+ self.visual_projection = self.visual_projection.with_dtype(dtype);
+ self.textual_projection = self.textual_projection.with_dtype(dtype);
+
+ self
+ }
+}
+
+impl_model_config_methods!(ModelConfig, model);
+impl_model_config_methods!(ModelConfig, visual);
+impl_model_config_methods!(ModelConfig, textual);
+impl_model_config_methods!(ModelConfig, encoder);
+impl_model_config_methods!(ModelConfig, decoder);
+impl_model_config_methods!(ModelConfig, visual_encoder);
+impl_model_config_methods!(ModelConfig, textual_encoder);
+impl_model_config_methods!(ModelConfig, visual_decoder);
+impl_model_config_methods!(ModelConfig, textual_decoder);
+impl_model_config_methods!(ModelConfig, textual_decoder_merged);
+impl_model_config_methods!(ModelConfig, size_encoder);
+impl_model_config_methods!(ModelConfig, size_decoder);
+impl_model_config_methods!(ModelConfig, coord_encoder);
+impl_model_config_methods!(ModelConfig, coord_decoder);
+impl_model_config_methods!(ModelConfig, visual_projection);
+impl_model_config_methods!(ModelConfig, textual_projection);
+impl_process_config_methods!(ModelConfig, processor);
diff --git a/src/io/dataloader.rs b/src/io/dataloader.rs
index 420fdff..19aed4e 100644
--- a/src/io/dataloader.rs
+++ b/src/io/dataloader.rs
@@ -367,14 +367,14 @@ impl DataLoader {
fn load_image_paths_from_folder(source: &str, exts: &[&str]) -> Result> {
let source_path = Path::new(source);
let mut paths: Vec = Vec::new();
- let options = MatchOptions {
+ let config = MatchOptions {
case_sensitive: false,
require_literal_separator: false,
require_literal_leading_dot: false,
};
for ext in exts.iter() {
let pattern = source_path.join(format!("*.{}", ext));
- let paths_: Vec = glob_with(pattern.to_str().unwrap(), options)?
+ let paths_: Vec = glob_with(pattern.to_str().unwrap(), config)?
.filter_map(|entry| entry.ok())
.collect();
paths.extend(paths_);
@@ -393,12 +393,12 @@ impl DataLoader {
}
fn glob(pattern: &str, sort: bool, case_sensitive: bool) -> anyhow::Result> {
- let options = MatchOptions {
+ let config = MatchOptions {
case_sensitive,
require_literal_separator: false,
require_literal_leading_dot: false,
};
- let mut paths: Vec = glob_with(pattern, options)?
+ let mut paths: Vec = glob_with(pattern, config)?
.filter_map(|entry| entry.ok())
.collect();
@@ -479,7 +479,7 @@ impl DataLoader {
self
}
- pub fn with_batch_size(mut self, x: usize) -> Self {
+ pub fn with_batch_size_all(mut self, x: usize) -> Self {
self.batch_size = x;
self
}
diff --git a/src/models/beit/config.rs b/src/models/beit/config.rs
index 46636b6..34eb389 100644
--- a/src/models/beit/config.rs
+++ b/src/models/beit/config.rs
@@ -1,10 +1,8 @@
-use crate::NAMES_IMAGENET_1K;
-
/// Model configuration for `BEiT`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn beit() -> Self {
Self::default()
- .with_model_name("beit")
+ .with_name("beit")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 224.into())
@@ -13,7 +11,7 @@ impl crate::Options {
.with_image_std(&[0.5, 0.5, 0.5])
.with_normalize(true)
.with_apply_softmax(true)
- .with_class_names(&NAMES_IMAGENET_1K)
+ .with_class_names(&crate::NAMES_IMAGENET_1K)
}
pub fn beit_base() -> Self {
diff --git a/src/models/ben2/config.rs b/src/models/ben2/config.rs
index b6b5632..f942025 100644
--- a/src/models/ben2/config.rs
+++ b/src/models/ben2/config.rs
@@ -1,5 +1,5 @@
/// Model configuration for `BEN2`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn ben2_base() -> Self {
Self::rmbg().with_model_file("ben2-base.onnx")
}
diff --git a/src/models/blip/config.rs b/src/models/blip/config.rs
index 038e046..0395d54 100644
--- a/src/models/blip/config.rs
+++ b/src/models/blip/config.rs
@@ -1,34 +1,24 @@
/// Model configuration for `BLIP`
-impl crate::Options {
- pub fn blip() -> Self {
- Self::default().with_model_name("blip").with_batch_size(1)
- }
-
+impl crate::ModelConfig {
#[allow(clippy::excessive_precision)]
- pub fn blip_visual() -> Self {
- Self::blip()
- .with_model_kind(crate::Kind::Vision)
- .with_model_ixx(0, 2, 384.into())
- .with_model_ixx(0, 3, 384.into())
+ pub fn blip() -> Self {
+ Self::default()
+ .with_name("blip")
+ .with_batch_size_all(1)
+ .with_visual_ixx(0, 1, 3.into())
+ .with_visual_ixx(0, 2, 384.into())
+ .with_visual_ixx(0, 3, 384.into())
.with_image_mean(&[0.48145466, 0.4578275, 0.40821073])
.with_image_std(&[0.26862954, 0.26130258, 0.27577711])
- .with_resize_filter("Bilinear")
- .with_normalize(true)
}
- pub fn blip_textual() -> Self {
- Self::blip().with_model_kind(crate::Kind::Language)
- }
-
- pub fn blip_v1_base_caption_visual() -> Self {
- Self::blip_visual()
- .with_model_version(1.into())
- .with_model_file("v1-base-caption-visual.onnx")
- }
-
- pub fn blip_v1_base_caption_textual() -> Self {
- Self::blip_textual()
- .with_model_version(1.into())
- .with_model_file("v1-base-caption-textual.onnx")
+ pub fn blip_v1_base_caption() -> Self {
+ Self::blip()
+ .with_version(1.into())
+ .with_visual_file("v1-base-caption-visual.onnx")
+ .with_textual_file("v1-base-caption-textual.onnx")
+ .with_tokenizer_file("blip/tokenizer.json")
+ .with_tokenizer_config_file("blip/tokenizer_config.json")
+ .with_special_tokens_map_file("blip/special_tokens_map.json")
}
}
diff --git a/src/models/blip/impl.rs b/src/models/blip/impl.rs
index 3719f8a..3a4e922 100644
--- a/src/models/blip/impl.rs
+++ b/src/models/blip/impl.rs
@@ -2,26 +2,34 @@ use aksr::Builder;
use anyhow::Result;
use ndarray::{s, Axis};
-use crate::{
- elapsed,
- models::{BaseModelTextual, BaseModelVisual},
- Image, LogitsSampler, Options, Ts, Xs, X, Y,
-};
+use crate::{elapsed, Engine, Image, LogitsSampler, ModelConfig, Processor, Ts, Xs, X, Y};
#[derive(Debug, Builder)]
pub struct Blip {
- visual: BaseModelVisual,
- textual: BaseModelTextual,
- ts: Ts,
+ visual: Engine,
+ textual: Engine,
+ batch: usize,
+ height: usize,
+ width: usize,
+ processor: Processor,
max_length: usize,
eos_token_id: u32,
+ ts: Ts,
}
impl Blip {
- pub fn new(options_visual: Options, options_textual: Options) -> Result {
- let visual = BaseModelVisual::new(options_visual)?;
- let textual = BaseModelTextual::new(options_textual)?;
- let ts = Ts::merge(&[visual.engine().ts(), textual.engine().ts()]);
+ pub fn new(config: ModelConfig) -> Result {
+ let visual = Engine::try_from_config(&config.visual)?;
+ let textual = Engine::try_from_config(&config.textual)?;
+ let (batch, height, width) = (
+ visual.batch().opt(),
+ visual.try_height().unwrap_or(&384.into()).opt(),
+ visual.try_width().unwrap_or(&384.into()).opt(),
+ );
+ let ts = Ts::merge(&[visual.ts(), textual.ts()]);
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
let max_length = 512;
let eos_token_id = 102;
@@ -31,17 +39,24 @@ impl Blip {
ts,
max_length,
eos_token_id,
+ batch,
+ height,
+ width,
+ processor,
})
}
pub fn encode_images(&mut self, xs: &[Image]) -> Result {
- self.visual.encode(xs)
+ let ys = self.processor.process_images(xs)?;
+ self.batch = xs.len(); // keep batch in sync with the number of input images
+ let ys = self.visual.run(ys.into())?;
+
+ Ok(ys[0].to_owned())
}
pub fn encode_texts(&mut self, text: Option<&str>) -> Result>> {
let input_ids = self
- .textual
- .processor()
+ .processor
.encode_text_ids(text.unwrap_or_default(), false)?;
Ok(vec![input_ids.clone(); self.batch()])
}
@@ -70,11 +85,11 @@ impl Blip {
let input_ids_attn_mask = X::ones(input_ids_nd.dims());
// decode
- let outputs = self.textual.inference(Xs::from(vec![
+ let outputs = self.textual.run(Xs::from(vec![
input_ids_nd,
input_ids_attn_mask,
image_embeds.clone(),
- X::ones(&[self.visual().batch(), image_embeds.dims()[1]]), // image_embeds_attn_mask
+ X::ones(&[self.batch(), image_embeds.dims()[1]]),
]))?;
// decode each token for each batch
@@ -102,7 +117,7 @@ impl Blip {
}
// batch decode
- let texts = self.textual.processor().decode_tokens_batch(
+ let texts = self.processor.decode_tokens_batch(
&token_ids
.into_iter()
.map(|v| v.into_iter().map(|x| x as u32).collect::>())
@@ -114,7 +129,6 @@ impl Blip {
.into_iter()
.map(|x| Y::default().with_texts(&[&x]))
.collect::>();
- // .into();
Ok(ys)
}
@@ -122,8 +136,4 @@ impl Blip {
pub fn summary(&mut self) {
self.ts.summary();
}
-
- pub fn batch(&self) -> usize {
- self.visual.batch() as _
- }
}
diff --git a/src/models/clip/config.rs b/src/models/clip/config.rs
index 0454261..d0712b2 100644
--- a/src/models/clip/config.rs
+++ b/src/models/clip/config.rs
@@ -1,71 +1,57 @@
-use crate::Kind;
-
/// Model configuration for `CLIP`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn clip() -> Self {
Self::default()
- .with_model_name("clip")
- .with_model_ixx(0, 0, 1.into())
- }
-
- pub fn clip_visual() -> Self {
- Self::clip()
- .with_model_kind(Kind::Vision)
- .with_model_ixx(0, 2, 224.into())
- .with_model_ixx(0, 3, 224.into())
+ .with_name("clip")
+ .with_batch_size_all(1)
+ .with_visual_ixx(0, 1, 3.into())
+ .with_visual_ixx(0, 2, 224.into())
+ .with_visual_ixx(0, 3, 224.into())
.with_image_mean(&[0.48145466, 0.4578275, 0.40821073])
.with_image_std(&[0.26862954, 0.2613026, 0.2757771])
- }
-
- pub fn clip_textual() -> Self {
- Self::clip()
- .with_model_kind(Kind::Language)
.with_model_max_length(77)
+ .with_tokenizer_file("clip/tokenizer.json")
+ .with_tokenizer_config_file("clip/tokenizer_config.json")
+ .with_special_tokens_map_file("clip/special_tokens_map.json")
+ .with_config_file("clip/config.json")
}
- pub fn clip_vit_b16_visual() -> Self {
- Self::clip_visual().with_model_file("vit-b16-visual.onnx")
+ pub fn clip_vit_b16() -> Self {
+ Self::clip()
+ .with_visual_file("vit-b16-visual.onnx")
+ .with_textual_file("vit-b16-textual.onnx")
}
- pub fn clip_vit_b16_textual() -> Self {
- Self::clip_textual().with_model_file("vit-b16-textual.onnx")
+ pub fn clip_vit_b32() -> Self {
+ Self::clip()
+ .with_visual_file("vit-b32-visual.onnx")
+ .with_textual_file("vit-b32-textual.onnx")
}
- pub fn clip_vit_b32_visual() -> Self {
- Self::clip_visual().with_model_file("vit-b32-visual.onnx")
+ pub fn clip_vit_l14() -> Self {
+ Self::clip()
+ .with_visual_file("vit-l14-visual.onnx")
+ .with_textual_file("vit-l14-textual.onnx")
}
- pub fn clip_vit_b32_textual() -> Self {
- Self::clip_textual().with_model_file("vit-b32-textual.onnx")
- }
-
- pub fn clip_vit_l14_visual() -> Self {
- Self::clip_visual().with_model_file("vit-l14-visual.onnx")
- }
-
- pub fn clip_vit_l14_textual() -> Self {
- Self::clip_textual().with_model_file("vit-l14-textual.onnx")
+ pub fn jina_clip() -> Self {
+ Self::default()
+ .with_name("jina-clip-v1")
+ .with_batch_size_all(1)
+ .with_visual_ixx(0, 1, 3.into())
+ .with_visual_ixx(0, 2, 224.into())
+ .with_visual_ixx(0, 3, 224.into())
+ .with_image_mean(&[0.48145466, 0.4578275, 0.40821073])
+ .with_image_std(&[0.26862954, 0.2613026, 0.2757771])
+ .with_tokenizer_file("jina-clip-v1/tokenizer.json")
+ .with_tokenizer_config_file("jina-clip-v1/tokenizer_config.json")
+ .with_special_tokens_map_file("jina-clip-v1/special_tokens_map.json")
+ .with_config_file("jina-clip-v1/config.json")
}
pub fn jina_clip_v1() -> Self {
- Self::default()
- .with_model_name("jina-clip-v1")
- .with_model_ixx(0, 0, 1.into())
- }
-
- pub fn jina_clip_v1_visual() -> Self {
- Self::jina_clip_v1()
- .with_model_kind(Kind::Vision)
- .with_model_ixx(0, 2, 224.into())
- .with_model_ixx(0, 3, 224.into())
- .with_image_mean(&[0.48145466, 0.4578275, 0.40821073])
- .with_image_std(&[0.26862954, 0.2613026, 0.2757771])
- .with_model_file("visual.onnx")
- }
-
- pub fn jina_clip_v1_textual() -> Self {
- Self::jina_clip_v1()
- .with_model_kind(Kind::Language)
- .with_model_file("textual.onnx")
+ Self::jina_clip()
+ .with_visual_file("visual.onnx")
+ .with_textual_file("textual.onnx")
}
}
diff --git a/src/models/clip/impl.rs b/src/models/clip/impl.rs
index 34f41ec..e3231b9 100644
--- a/src/models/clip/impl.rs
+++ b/src/models/clip/impl.rs
@@ -2,11 +2,12 @@ use aksr::Builder;
use anyhow::Result;
use ndarray::Array2;
-use crate::{elapsed, Engine, Image, Options, Processor, Ts, Xs, X};
+use crate::{elapsed, Engine, Image, ModelConfig, Processor, Ts, X};
#[derive(Debug, Builder)]
-pub struct ClipVisual {
- engine: Engine,
+pub struct Clip {
+ visual: Engine,
+ textual: Engine,
height: usize,
width: usize,
batch: usize,
@@ -14,22 +15,23 @@ pub struct ClipVisual {
ts: Ts,
}
-impl ClipVisual {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
- let (batch, height, width, ts) = (
- engine.batch().opt(),
- engine.try_height().unwrap_or(&224.into()).opt(),
- engine.try_width().unwrap_or(&224.into()).opt(),
- engine.ts.clone(),
+impl Clip {
+ pub fn new(config: ModelConfig) -> Result {
+ let visual = Engine::try_from_config(&config.visual)?;
+ let textual = Engine::try_from_config(&config.textual)?;
+ let (batch, height, width) = (
+ visual.batch().opt(),
+ visual.try_height().unwrap_or(&224.into()).opt(),
+ visual.try_width().unwrap_or(&224.into()).opt(),
);
- let processor = options
- .to_processor()?
+ let ts = Ts::merge(&[visual.ts(), textual.ts()]);
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
Ok(Self {
- engine,
+ textual,
+ visual,
height,
width,
batch,
@@ -38,111 +40,39 @@ impl ClipVisual {
})
}
- pub fn preprocess(&mut self, xs: &[Image]) -> Result {
- let x = self.processor.process_images(xs)?;
-
- Ok(x.into())
- }
-
- pub fn inference(&mut self, xs: Xs) -> Result {
- self.engine.run(xs)
- }
-
pub fn encode_images(&mut self, xs: &[Image]) -> Result {
- let xs = elapsed!("visual-preprocess", self.ts, { self.preprocess(xs)? });
- let xs = elapsed!("visual-inference", self.ts, { self.inference(xs)? });
+ let xs = elapsed!("visual-preprocess", self.ts, {
+ self.processor.process_images(xs)?
+ });
+ let xs = elapsed!("visual-inference", self.ts, { self.visual.run(xs.into())? });
let x = elapsed!("visual-postprocess", self.ts, { xs[0].to_owned() });
Ok(x)
}
-}
-
-#[derive(Debug, Builder)]
-pub struct ClipTextual {
- engine: Engine,
- batch: usize,
- processor: Processor,
- ts: Ts,
-}
-
-impl ClipTextual {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
- let (batch, ts) = (engine.batch().opt(), engine.ts.clone());
- let processor = options.to_processor()?;
-
- Ok(Self {
- engine,
- batch,
- processor,
- ts,
- })
- }
-
- pub fn preprocess(&self, xs: &[&str]) -> Result {
- let encodings: Vec = self
- .processor
- .encode_texts_ids(xs, false)? // skip_special_tokens
- .into_iter()
- .flatten()
- .collect();
-
- let x: X = Array2::from_shape_vec((xs.len(), encodings.len() / xs.len()), encodings)?
- .into_dyn()
- .into();
-
- Ok(x.into())
- }
-
- pub fn inference(&mut self, xs: Xs) -> Result {
- self.engine.run(xs)
- }
pub fn encode_texts(&mut self, xs: &[&str]) -> Result {
- let xs = elapsed!("textual-preprocess", self.ts, { self.preprocess(xs)? });
- let xs = elapsed!("textual-inference", self.ts, { self.inference(xs)? });
+ let xs = elapsed!("textual-preprocess", self.ts, {
+ let encodings: Vec = self
+ .processor
+ .encode_texts_ids(xs, false)?
+ .into_iter()
+ .flatten()
+ .collect();
+
+ let x: X = Array2::from_shape_vec((xs.len(), encodings.len() / xs.len()), encodings)?
+ .into_dyn()
+ .into();
+ x
+ });
+ let xs = elapsed!("textual-inference", self.ts, {
+ self.textual.run(xs.into())?
+ });
let x = elapsed!("textual-postprocess", self.ts, { xs[0].to_owned() });
Ok(x)
}
-}
-
-#[derive(Debug, Builder)]
-pub struct Clip {
- textual: ClipTextual,
- visual: ClipVisual,
- ts: Ts,
-}
-
-impl Clip {
- pub fn new(options_visual: Options, options_textual: Options) -> Result {
- let visual = ClipVisual::new(options_visual)?;
- let textual = ClipTextual::new(options_textual)?;
- // let ts = Ts::merge(&[visual.engine().ts(), textual.engine().ts()]);
- let ts = Ts::default();
-
- Ok(Self {
- textual,
- visual,
- ts,
- })
- }
-
- pub fn encode_images(&mut self, xs: &[Image]) -> Result {
- let x = elapsed!("encode_images", self.ts, { self.visual.encode_images(xs)? });
- Ok(x)
- }
-
- pub fn encode_texts(&mut self, xs: &[&str]) -> Result {
- let x = elapsed!("encode_texts", self.ts, { self.textual.encode_texts(xs)? });
- Ok(x)
- }
pub fn summary(&mut self) {
- // self.ts.clear();
- // self.ts = Ts::merge(&[&self.ts, self.visual.ts(), self.textual.ts()]);
self.ts.summary();
- self.visual.ts().summary();
- self.textual.ts().summary();
}
}
diff --git a/src/models/convnext/config.rs b/src/models/convnext/config.rs
index 0ca435d..30d0074 100644
--- a/src/models/convnext/config.rs
+++ b/src/models/convnext/config.rs
@@ -1,10 +1,10 @@
use crate::NAMES_IMAGENET_1K;
/// Model configuration for `ConvNeXt`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn convnext() -> Self {
Self::default()
- .with_model_name("convnext")
+ .with_name("convnext")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 224.into())
diff --git a/src/models/d_fine/config.rs b/src/models/d_fine/config.rs
index ea2cea7..16de585 100644
--- a/src/models/d_fine/config.rs
+++ b/src/models/d_fine/config.rs
@@ -1,7 +1,7 @@
/// Model configuration for `d_fine`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn d_fine() -> Self {
- Self::rtdetr().with_model_name("d-fine")
+ Self::rtdetr().with_name("d-fine")
}
pub fn d_fine_n_coco() -> Self {
diff --git a/src/models/db/config.rs b/src/models/db/config.rs
index c594526..0493237 100644
--- a/src/models/db/config.rs
+++ b/src/models/db/config.rs
@@ -1,8 +1,8 @@
/// Model configuration for [DB](https://github.com/MhLiao/DB) and [PaddleOCR-Det](https://github.com/PaddlePaddle/PaddleOCR)
-impl crate::Options {
+impl crate::ModelConfig {
pub fn db() -> Self {
Self::default()
- .with_model_name("db")
+ .with_name("db")
.with_model_ixx(0, 0, (1, 1, 8).into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, (608, 960, 1600).into())
@@ -11,7 +11,7 @@ impl crate::Options {
.with_normalize(true)
.with_image_mean(&[0.485, 0.456, 0.406])
.with_image_std(&[0.229, 0.224, 0.225])
- .with_binary_thresh(0.2)
+ .with_db_binary_thresh(0.2)
.with_class_confs(&[0.35])
.with_min_width(5.0)
.with_min_height(12.0)
diff --git a/src/models/db/impl.rs b/src/models/db/impl.rs
index 651fa87..af7d874 100644
--- a/src/models/db/impl.rs
+++ b/src/models/db/impl.rs
@@ -4,7 +4,8 @@ use ndarray::Axis;
use rayon::prelude::*;
use crate::{
- elapsed, DynConf, Engine, Hbb, Image, Mask, Obb, Ops, Options, Polygon, Processor, Ts, Xs, Y,
+ elapsed, DynConf, Engine, Hbb, Image, Mask, ModelConfig, Obb, Ops, Polygon, Processor, Ts, Xs,
+ Y,
};
#[derive(Debug, Builder)]
@@ -24,8 +25,8 @@ pub struct DB {
}
impl DB {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let (batch, height, width, ts, spec) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&960.into()).opt(),
@@ -33,15 +34,14 @@ impl DB {
engine.ts.clone(),
engine.spec().to_owned(),
);
- let processor = options
- .to_processor()?
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
- let confs = DynConf::new(options.class_confs(), 1);
- let binary_thresh = options.binary_thresh().unwrap_or(0.2);
- let unclip_ratio = options.unclip_ratio().unwrap_or(1.5);
- let min_width = options.min_width().unwrap_or(12.0);
- let min_height = options.min_height().unwrap_or(5.0);
+ let confs = DynConf::new(config.class_confs(), 1);
+ let binary_thresh = config.db_binary_thresh().unwrap_or(0.2);
+ let unclip_ratio = config.db_unclip_ratio().unwrap_or(1.5);
+ let min_width = config.min_width().unwrap_or(12.0);
+ let min_height = config.min_height().unwrap_or(5.0);
Ok(Self {
engine,
diff --git a/src/models/deim/config.rs b/src/models/deim/config.rs
index 10c4a0a..81177ae 100644
--- a/src/models/deim/config.rs
+++ b/src/models/deim/config.rs
@@ -1,7 +1,7 @@
/// Model configuration for `DEIM`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn deim() -> Self {
- Self::d_fine().with_model_name("deim")
+ Self::d_fine().with_name("deim")
}
pub fn deim_dfine_s_coco() -> Self {
diff --git a/src/models/deit/config.rs b/src/models/deit/config.rs
index 8be8000..1f7c3b9 100644
--- a/src/models/deit/config.rs
+++ b/src/models/deit/config.rs
@@ -1,10 +1,10 @@
use crate::NAMES_IMAGENET_1K;
/// Model configuration for `DeiT`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn deit() -> Self {
Self::default()
- .with_model_name("deit")
+ .with_name("deit")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 224.into())
diff --git a/src/models/depth_anything/config.rs b/src/models/depth_anything/config.rs
index d19344d..e8adefd 100644
--- a/src/models/depth_anything/config.rs
+++ b/src/models/depth_anything/config.rs
@@ -1,8 +1,8 @@
/// Model configuration for `DepthAnything`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn depth_anything() -> Self {
Self::default()
- .with_model_name("depth-anything")
+ .with_name("depth-anything")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, (384, 518, 1024).into())
@@ -14,26 +14,26 @@ impl crate::Options {
}
pub fn depth_anything_s() -> Self {
- Self::depth_anything().with_model_scale(crate::Scale::S)
+ Self::depth_anything().with_scale(crate::Scale::S)
}
pub fn depth_anything_v1() -> Self {
- Self::depth_anything().with_model_version(1.into())
+ Self::depth_anything().with_version(1.into())
}
pub fn depth_anything_v2() -> Self {
- Self::depth_anything().with_model_version(2.into())
+ Self::depth_anything().with_version(2.into())
}
pub fn depth_anything_v1_small() -> Self {
Self::depth_anything_v1()
- .with_model_scale(crate::Scale::S)
+ .with_scale(crate::Scale::S)
.with_model_file("v1-s.onnx")
}
pub fn depth_anything_v2_small() -> Self {
Self::depth_anything_v2()
- .with_model_scale(crate::Scale::S)
+ .with_scale(crate::Scale::S)
.with_model_file("v2-s.onnx")
}
}
diff --git a/src/models/depth_anything/impl.rs b/src/models/depth_anything/impl.rs
index f094cc6..778fcc8 100644
--- a/src/models/depth_anything/impl.rs
+++ b/src/models/depth_anything/impl.rs
@@ -1,7 +1,7 @@
use aksr::Builder;
use anyhow::Result;
-use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y};
+use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y};
#[derive(Debug, Builder)]
pub struct DepthAnything {
@@ -15,8 +15,8 @@ pub struct DepthAnything {
}
impl DepthAnything {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
@@ -25,9 +25,7 @@ impl DepthAnything {
engine.try_width().unwrap_or(&518.into()).opt(),
engine.ts().clone(),
);
-
- let processor = options
- .to_processor()?
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
diff --git a/src/models/depth_pro/config.rs b/src/models/depth_pro/config.rs
index 451682e..5cbddb1 100644
--- a/src/models/depth_pro/config.rs
+++ b/src/models/depth_pro/config.rs
@@ -1,8 +1,8 @@
/// Model configuration for `DepthPro`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn depth_pro() -> Self {
Self::default()
- .with_model_name("depth-pro")
+ .with_name("depth-pro")
.with_model_ixx(0, 0, 1.into()) // batch. Note: now only support batch_size = 1
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 1536.into())
@@ -12,16 +12,4 @@ impl crate::Options {
.with_resize_mode(crate::ResizeMode::FitExact)
.with_normalize(true)
}
-
- // pub fn depth_pro_q4f16() -> Self {
- // Self::depth_pro().with_model_file("q4f16.onnx")
- // }
-
- // pub fn depth_pro_fp16() -> Self {
- // Self::depth_pro().with_model_file("fp16.onnx")
- // }
-
- // pub fn depth_pro_bnb4() -> Self {
- // Self::depth_pro().with_model_file("bnb4.onnx")
- // }
}
diff --git a/src/models/depth_pro/impl.rs b/src/models/depth_pro/impl.rs
index 6e1d254..6301437 100644
--- a/src/models/depth_pro/impl.rs
+++ b/src/models/depth_pro/impl.rs
@@ -2,7 +2,7 @@ use aksr::Builder;
use anyhow::Result;
use ndarray::Axis;
-use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y};
+use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y};
#[derive(Builder, Debug)]
pub struct DepthPro {
@@ -16,8 +16,8 @@ pub struct DepthPro {
}
impl DepthPro {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
@@ -25,8 +25,7 @@ impl DepthPro {
engine.try_width().unwrap_or(&512.into()).opt(),
engine.ts().clone(),
);
- let processor = options
- .to_processor()?
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
diff --git a/src/models/dinov2/config.rs b/src/models/dinov2/config.rs
index abf7696..60df927 100644
--- a/src/models/dinov2/config.rs
+++ b/src/models/dinov2/config.rs
@@ -1,8 +1,8 @@
/// Model configuration for `DINOv2`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn dinov2() -> Self {
Self::default()
- .with_model_name("dinov2")
+ .with_name("dinov2")
.with_model_ixx(0, 0, (1, 1, 8).into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 224.into())
@@ -16,13 +16,13 @@ impl crate::Options {
pub fn dinov2_small() -> Self {
Self::dinov2()
- .with_model_scale(crate::Scale::S)
+ .with_scale(crate::Scale::S)
.with_model_file("s.onnx")
}
pub fn dinov2_base() -> Self {
Self::dinov2()
- .with_model_scale(crate::Scale::B)
+ .with_scale(crate::Scale::B)
.with_model_file("b.onnx")
}
}
diff --git a/src/models/dinov2/impl.rs b/src/models/dinov2/impl.rs
index 813851d..c9a070c 100644
--- a/src/models/dinov2/impl.rs
+++ b/src/models/dinov2/impl.rs
@@ -1,7 +1,7 @@
use aksr::Builder;
use anyhow::Result;
-use crate::{elapsed, Engine, Image, Options, Processor, Scale, Ts, Xs, X};
+use crate::{elapsed, Engine, Image, ModelConfig, Processor, Scale, Ts, Xs, X};
#[derive(Builder, Debug)]
pub struct DINOv2 {
@@ -15,15 +15,15 @@ pub struct DINOv2 {
}
impl DINOv2 {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let (batch, height, width, ts) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&384.into()).opt(),
engine.try_width().unwrap_or(&384.into()).opt(),
engine.ts.clone(),
);
- let dim = match options.model_scale() {
+ let dim = match &config.scale {
Some(Scale::S) => 384,
Some(Scale::B) => 768,
Some(Scale::L) => 1024,
@@ -31,8 +31,7 @@ impl DINOv2 {
Some(x) => anyhow::bail!("Unsupported scale: {:?}", x),
None => anyhow::bail!("No model scale specified"),
};
- let processor = options
- .to_processor()?
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
diff --git a/src/models/fast/config.rs b/src/models/fast/config.rs
index 6ec39c4..02c4250 100644
--- a/src/models/fast/config.rs
+++ b/src/models/fast/config.rs
@@ -1,8 +1,8 @@
/// Model configuration for [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://github.com/czczup/FAST)
-impl crate::Options {
+impl crate::ModelConfig {
pub fn fast() -> Self {
Self::db()
- .with_model_name("fast")
+ .with_name("fast")
.with_image_mean(&[0.798, 0.785, 0.772])
.with_image_std(&[0.264, 0.2749, 0.287])
}
diff --git a/src/models/fastvit/config.rs b/src/models/fastvit/config.rs
index 71cdd31..10038ef 100644
--- a/src/models/fastvit/config.rs
+++ b/src/models/fastvit/config.rs
@@ -1,10 +1,10 @@
use crate::NAMES_IMAGENET_1K;
/// Model configuration for `FastViT`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn fastvit() -> Self {
Self::default()
- .with_model_name("fastvit")
+ .with_name("fastvit")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 224.into())
diff --git a/src/models/florence2/config.rs b/src/models/florence2/config.rs
index 8ef74ac..e1fac46 100644
--- a/src/models/florence2/config.rs
+++ b/src/models/florence2/config.rs
@@ -1,59 +1,31 @@
/// Model configuration for `Florence2`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn florence2() -> Self {
Self::default()
- .with_model_name("florence2")
- .with_batch_size(1)
- }
-
- pub fn florence2_visual() -> Self {
- Self::florence2()
- .with_model_kind(crate::Kind::Vision)
- .with_model_ixx(0, 2, 768.into())
- .with_model_ixx(0, 3, 768.into())
+ .with_name("florence2")
+ .with_batch_size_all(1)
+ .with_visual_ixx(0, 1, 3.into())
+ .with_visual_ixx(0, 2, 768.into())
+ .with_visual_ixx(0, 3, 768.into())
.with_image_mean(&[0.485, 0.456, 0.406])
.with_image_std(&[0.229, 0.224, 0.225])
- .with_resize_filter("Bilinear")
- .with_normalize(true)
}
- pub fn florence2_textual() -> Self {
- Self::florence2().with_model_kind(crate::Kind::Language)
+ pub fn florence2_base() -> Self {
+ Self::florence2()
+ .with_scale(crate::Scale::B)
+ .with_visual_file("base-vision-encoder.onnx")
+ .with_textual_file("base-embed-tokens.onnx")
+ .with_textual_encoder_file("base-encoder.onnx")
+ .with_textual_decoder_file("base-decoder.onnx")
+ .with_textual_decoder_merged_file("base-decoder-merged.onnx")
+ .with_tokenizer_file("florence2/tokenizer.json")
+ .with_config_file("florence2/config.json")
+ .with_special_tokens_map_file("florence2/special_tokens_map.json")
+ .with_tokenizer_config_file("florence2/tokenizer_config.json")
}
- pub fn florence2_visual_base() -> Self {
- Self::florence2_visual().with_model_scale(crate::Scale::B)
- }
-
- pub fn florence2_textual_base() -> Self {
- Self::florence2_textual().with_model_scale(crate::Scale::B)
- }
-
- pub fn florence2_visual_large() -> Self {
- Self::florence2_visual().with_model_scale(crate::Scale::L)
- }
-
- pub fn florence2_textual_large() -> Self {
- Self::florence2_textual().with_model_scale(crate::Scale::L)
- }
-
- pub fn florence2_visual_encoder_base() -> Self {
- Self::florence2_visual_base().with_model_file("base-vision-encoder.onnx")
- }
-
- pub fn florence2_textual_embed_base() -> Self {
- Self::florence2_textual_base().with_model_file("base-embed-tokens.onnx")
- }
-
- pub fn florence2_texual_encoder_base() -> Self {
- Self::florence2_textual_base().with_model_file("base-encoder.onnx")
- }
-
- pub fn florence2_texual_decoder_base() -> Self {
- Self::florence2_textual_base().with_model_file("base-decoder.onnx")
- }
-
- pub fn florence2_texual_decoder_merged_base() -> Self {
- Self::florence2_textual_base().with_model_file("base-decoder-merged.onnx")
+ pub fn florence2_large() -> Self {
+ todo!()
}
}
diff --git a/src/models/florence2/impl.rs b/src/models/florence2/impl.rs
index 6a56775..b1f68b0 100644
--- a/src/models/florence2/impl.rs
+++ b/src/models/florence2/impl.rs
@@ -4,51 +4,59 @@ use ndarray::{s, Axis};
use rayon::prelude::*;
use crate::{
- elapsed,
- models::{BaseModelTextual, BaseModelVisual, Quantizer},
- Hbb, Image, LogitsSampler, Options, Polygon, Scale, Task, Ts, Xs, X, Y,
+ elapsed, models::Quantizer, Engine, Hbb, Image, LogitsSampler, ModelConfig, Polygon, Processor,
+ Scale, Task, Ts, Xs, X, Y,
};
#[derive(Debug, Builder)]
pub struct Florence2 {
- pub vision_encoder: BaseModelVisual,
- pub text_embed: BaseModelTextual,
- pub encoder: BaseModelTextual,
- pub decoder: BaseModelTextual,
- pub decoder_merged: BaseModelTextual,
+ pub vision_encoder: Engine,
+ pub text_embed: Engine,
+ pub encoder: Engine,
+ pub decoder: Engine,
+ pub decoder_merged: Engine,
ts: Ts,
quantizer: Quantizer,
max_length: usize,
eos_token_id: u32,
decoder_start_token_id: u32,
n_kvs: usize,
+ height: usize,
+ width: usize,
+ batch: usize,
+ processor: Processor,
}
impl Florence2 {
- pub fn new(
- options_vision_encoder: Options,
- options_text_embed: Options,
- options_encoder: Options,
- options_decoder: Options,
- options_decoder_merged: Options,
- ) -> Result {
- let vision_encoder = BaseModelVisual::new(options_vision_encoder)?;
- let text_embed = BaseModelTextual::new(options_text_embed)?;
- let encoder = BaseModelTextual::new(options_encoder)?;
- let decoder = BaseModelTextual::new(options_decoder)?;
- let decoder_merged = BaseModelTextual::new(options_decoder_merged)?;
+ pub fn new(config: ModelConfig) -> Result {
+ let vision_encoder = Engine::try_from_config(&config.visual)?;
+ let text_embed = Engine::try_from_config(&config.textual)?;
+ let encoder = Engine::try_from_config(&config.textual_encoder)?;
+ let decoder = Engine::try_from_config(&config.textual_decoder)?;
+ let decoder_merged = Engine::try_from_config(&config.textual_decoder_merged)?;
+
+ let (batch, height, width) = (
+ vision_encoder.batch().opt(),
+ vision_encoder.try_height().unwrap_or(&1024.into()).opt(),
+ vision_encoder.try_width().unwrap_or(&1024.into()).opt(),
+ );
+
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
+
let quantizer = Quantizer::default();
let ts = Ts::merge(&[
- vision_encoder.engine().ts(),
- text_embed.engine().ts(),
- encoder.engine().ts(),
- decoder.engine().ts(),
- decoder_merged.engine().ts(),
+ vision_encoder.ts(),
+ text_embed.ts(),
+ encoder.ts(),
+ decoder.ts(),
+ decoder_merged.ts(),
]);
let max_length = 1024;
let eos_token_id = 2;
let decoder_start_token_id = 2;
- let n_kvs = match decoder.scale() {
+ let n_kvs = match config.scale {
Some(Scale::B) => 6,
Some(Scale::L) => 12,
_ => unimplemented!(),
@@ -66,6 +74,10 @@ impl Florence2 {
eos_token_id,
decoder_start_token_id,
n_kvs,
+ batch,
+ height,
+ width,
+ processor,
})
}
@@ -97,12 +109,12 @@ impl Florence2 {
.map(|im| {
let text = Self::process_task(task, im.height() as _, im.width() as _)
.prompt_for_florence2()?;
- let ids = self.text_embed.processor().encode_text_ids(&text, true)?;
+ let ids = self.processor.encode_text_ids(&text, true)?;
X::from(ids).insert_axis(0)
})
.collect::, _>>()?;
let x = X::concat(&xs, 0)?;
- let xs = self.text_embed.inference(x.into())?;
+ let xs = self.text_embed.run(x.into())?;
let x = xs[0].to_owned();
Ok(x)
@@ -110,7 +122,10 @@ impl Florence2 {
pub fn forward(&mut self, xs_visual: &[Image], x_textual: &Task) -> Result> {
let visual_embeddings = elapsed!("visual-encode", self.ts, {
- self.vision_encoder.encode(xs_visual)?
+ let xs = self.processor.process_images(xs_visual)?;
+ self.batch = xs_visual.len(); // keep batch in sync with the number of input images
+ let xs = self.vision_encoder.run(xs.into())?;
+ xs[0].to_owned()
});
let textual_embedding = elapsed!("textual-encode", self.ts, {
@@ -141,7 +156,7 @@ impl Florence2 {
let attention_mask = X::ones(&[self.batch(), inputs_embeds.dims()[1]]);
// encoder
- let last_hidden_state = self.encoder.inference(Xs::from(vec![
+ let last_hidden_state = self.encoder.run(Xs::from(vec![
attention_mask.clone(),
inputs_embeds.clone(),
]))?[0]
@@ -150,7 +165,7 @@ impl Florence2 {
// decoder
let inputs_embeds = inputs_embeds.slice(s![.., -1.., ..]);
let inputs_embeds = X::from(inputs_embeds.to_owned().into_dyn());
- let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
+ let mut decoder_outputs = self.decoder.run(Xs::from(vec![
attention_mask.clone(),
last_hidden_state.clone(),
inputs_embeds,
@@ -215,7 +230,7 @@ impl Florence2 {
// decode
let next_tokens = X::from(last_tokens.clone()).insert_axis(1)?;
- let inputs_embeds = &self.text_embed.inference(Xs::from(next_tokens))?[0].clone();
+ let inputs_embeds = &self.text_embed.run(Xs::from(next_tokens))?[0].clone();
let use_cache = X::ones(&[1]);
let mut xs = vec![
attention_mask.clone(),
@@ -229,13 +244,13 @@ impl Florence2 {
xs.push(encoder_kvs[i * 2 + 1].clone());
}
xs.push(use_cache);
- decoder_outputs = self.decoder_merged.inference(xs.into())?;
+ decoder_outputs = self.decoder_merged.run(xs.into())?;
}
// batch decode
let texts = self
- .text_embed
- .processor()
+ // decoding now goes through the shared processor (text_embed no longer owns one)
+ .processor
.decode_tokens_batch(&token_ids, false)?;
Ok(texts)
@@ -416,10 +431,6 @@ impl Florence2 {
Ok(ys)
}
- pub fn batch(&self) -> usize {
- self.vision_encoder.batch() as _
- }
-
pub fn summary(&mut self) {
self.ts.summary();
}
diff --git a/src/models/grounding_dino/config.rs b/src/models/grounding_dino/config.rs
index ce7096d..1ff2b34 100644
--- a/src/models/grounding_dino/config.rs
+++ b/src/models/grounding_dino/config.rs
@@ -1,9 +1,8 @@
/// Model configuration for `GroundingDino`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn grounding_dino() -> Self {
Self::default()
- .with_model_name("grounding-dino")
- .with_model_kind(crate::Kind::VisionLanguage)
+ .with_name("grounding-dino")
.with_model_ixx(0, 0, 1.into()) // TODO: current onnx model does not support bs > 1
.with_model_ixx(0, 2, 800.into()) // TODO: matters
.with_model_ixx(0, 3, 1200.into()) // TODO: matters
@@ -11,9 +10,10 @@ impl crate::Options {
.with_resize_filter("CatmullRom")
.with_image_mean(&[0.485, 0.456, 0.406])
.with_image_std(&[0.229, 0.224, 0.225])
- .with_normalize(true)
- .with_class_confs(&[0.25])
- .with_text_confs(&[0.25])
+ .with_tokenizer_file("grounding-dino/tokenizer.json")
+ .with_config_file("grounding-dino/config.json")
+ .with_special_tokens_map_file("grounding-dino/special_tokens_map.json")
+ .with_tokenizer_config_file("grounding-dino/tokenizer_config.json")
}
pub fn grounding_dino_tiny() -> Self {
diff --git a/src/models/grounding_dino/impl.rs b/src/models/grounding_dino/impl.rs
index d03e948..d1082d9 100644
--- a/src/models/grounding_dino/impl.rs
+++ b/src/models/grounding_dino/impl.rs
@@ -4,7 +4,7 @@ use ndarray::{s, Array2, Axis};
use rayon::prelude::*;
use std::fmt::Write;
-use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, X, Y};
+use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y};
#[derive(Builder, Debug)]
pub struct GroundingDINO {
@@ -24,8 +24,8 @@ pub struct GroundingDINO {
}
impl GroundingDINO {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
@@ -33,11 +33,8 @@ impl GroundingDINO {
engine.try_width().unwrap_or(&1200.into()).opt(),
engine.ts().clone(),
);
- let processor = options
- .to_processor()?
- .with_image_width(width as _)
- .with_image_height(height as _);
- let class_names = options
+
+ let class_names = config
.text_names
.as_ref()
.and_then(|v| {
@@ -48,16 +45,20 @@ impl GroundingDINO {
.collect();
(!v.is_empty()).then_some(v)
})
- .ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the options. Ensure the 'text_names' field is non-empty and contains valid class names."))?;
+ .ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the config. Ensure the 'text_names' field is non-empty and contains valid class names."))?;
let text_prompt = class_names.iter().fold(String::new(), |mut acc, text| {
write!(&mut acc, "{}.", text).unwrap();
acc
});
+
+ let confs_visual = DynConf::new(config.class_confs(), class_names.len());
+ let confs_textual = DynConf::new(config.text_confs(), class_names.len());
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
let token_ids = processor.encode_text_ids(&text_prompt, true)?;
let tokens = processor.encode_text_tokens(&text_prompt, true)?;
let class_ids_map = Self::process_class_ids(&tokens);
- let confs_visual = DynConf::new(options.class_confs(), class_names.len());
- let confs_textual = DynConf::new(options.text_confs(), class_names.len());
Ok(Self {
engine,
diff --git a/src/models/linknet/config.rs b/src/models/linknet/config.rs
index c1c666a..9c952b1 100644
--- a/src/models/linknet/config.rs
+++ b/src/models/linknet/config.rs
@@ -1,8 +1,8 @@
/// Model configuration for [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/abs/1707.03718)
-impl crate::Options {
+impl crate::ModelConfig {
pub fn linknet() -> Self {
Self::fast()
- .with_model_name("linknet")
+ .with_name("linknet")
.with_image_mean(&[0.798, 0.785, 0.772])
.with_image_std(&[0.264, 0.2749, 0.287])
}
diff --git a/src/models/mobileone/config.rs b/src/models/mobileone/config.rs
index a76f271..1190d46 100644
--- a/src/models/mobileone/config.rs
+++ b/src/models/mobileone/config.rs
@@ -1,10 +1,10 @@
use crate::NAMES_IMAGENET_1K;
/// Model configuration for `MobileOne`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn mobileone() -> Self {
Self::default()
- .with_model_name("mobileone")
+ .with_name("mobileone")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 224.into())
diff --git a/src/models/modnet/config.rs b/src/models/modnet/config.rs
index 05174d2..c72dd53 100644
--- a/src/models/modnet/config.rs
+++ b/src/models/modnet/config.rs
@@ -1,8 +1,8 @@
/// Model configuration for `MODNet`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn modnet() -> Self {
Self::default()
- .with_model_name("modnet")
+ .with_name("modnet")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 2, (416, 512, 800).into())
.with_model_ixx(0, 3, (416, 512, 800).into())
diff --git a/src/models/modnet/impl.rs b/src/models/modnet/impl.rs
index 5322ab1..9dff8b3 100644
--- a/src/models/modnet/impl.rs
+++ b/src/models/modnet/impl.rs
@@ -2,7 +2,7 @@ use aksr::Builder;
use anyhow::Result;
use ndarray::Axis;
-use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y};
+use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y};
#[derive(Builder, Debug)]
pub struct MODNet {
@@ -16,8 +16,8 @@ pub struct MODNet {
}
impl MODNet {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
@@ -25,8 +25,7 @@ impl MODNet {
engine.try_width().unwrap_or(&512.into()).opt(),
engine.ts().clone(),
);
- let processor = options
- .to_processor()?
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
diff --git a/src/models/moondream2/config.rs b/src/models/moondream2/config.rs
index 96d0bf9..d5c7642 100644
--- a/src/models/moondream2/config.rs
+++ b/src/models/moondream2/config.rs
@@ -1,117 +1,47 @@
/// Model configuration for `moondream2`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn moondream2() -> Self {
Self::default()
- .with_model_name("moondream2")
- .with_model_num_dry_run(0)
+ .with_name("moondream2")
+ .with_visual_encoder_ixx(0, 0, (1, 3, 4).into()) // patch count
+ .with_image_mean(&[0.5, 0.5, 0.5])
+ .with_image_std(&[0.5, 0.5, 0.5])
+ .with_resize_mode(crate::ResizeMode::FitExact)
+ .with_resize_filter("catmullrom")
+ .with_visual_projection_ixx(0, 0, 1.into())
+ .with_textual_encoder_ixx(0, 0, 1.into())
+ .with_textual_decoder_ixx(0, 0, 1.into())
+ .with_size_encoder_ixx(0, 0, 1.into())
+ .with_size_decoder_ixx(0, 0, 1.into())
+ .with_coord_encoder_ixx(0, 0, 1.into())
+ .with_coord_decoder_ixx(0, 0, 1.into())
+ .with_tokenizer_file("moondream2/tokenizer.json")
+ .with_tokenizer_config_file("moondream2/tokenizer_config.json")
}
pub fn moondream2_0_5b() -> Self {
- Self::moondream2().with_model_scale(crate::Scale::Billion(0.5))
+ Self::moondream2()
+ .with_scale(crate::Scale::Billion(0.5))
+ .with_visual_encoder_file("0.5b-vision-encoder.onnx")
+ .with_visual_projection_file("0.5b-vision-projection.onnx")
+ .with_textual_decoder_file("0.5b-text-decoder.onnx")
+ .with_textual_encoder_file("0.5b-text-encoder.onnx")
+ .with_coord_encoder_file("0.5b-coord-encoder.onnx")
+ .with_coord_decoder_file("0.5b-coord-decoder.onnx")
+ .with_size_encoder_file("0.5b-size-encoder.onnx")
+ .with_size_decoder_file("0.5b-size-decoder.onnx")
}
- pub fn moondream2_0_5b_vision_encoder() -> Self {
- Self::moondream2_0_5b()
- .with_model_ixx(0, 0, (1, 3, 4).into()) // patch count
- .with_model_kind(crate::Kind::Vision)
- .with_image_mean(&[0.5, 0.5, 0.5])
- .with_image_std(&[0.5, 0.5, 0.5])
- .with_normalize(true)
- .with_resize_mode(crate::ResizeMode::FitExact)
- .with_resize_filter("catmullrom")
- .with_model_file("0.5b-vision-encoder.onnx")
- }
-
- pub fn moondream2_0_5b_vision_projection() -> Self {
- Self::moondream2_0_5b()
- .with_batch_size(1)
- .with_model_kind(crate::Kind::Vision)
- .with_model_file("0.5b-vision-projection.onnx")
- }
-
- pub fn moondream2_0_5b_text_decoder() -> Self {
- Self::moondream2_0_5b()
- .with_batch_size(1)
- .with_model_kind(crate::Kind::Language)
- .with_model_file("0.5b-text-decoder.onnx")
- }
-
- pub fn moondream2_0_5b_text_encoder() -> Self {
- Self::moondream2_0_5b()
- .with_batch_size(1)
- .with_model_kind(crate::Kind::Language)
- .with_model_file("0.5b-text-encoder.onnx")
- }
-
- pub fn moondream2_0_5b_coord_encoder() -> Self {
- Self::moondream2_0_5b()
- .with_batch_size(1)
- .with_model_file("0.5b-coord-encoder.onnx")
- }
-
- pub fn moondream2_0_5b_coord_decoder() -> Self {
- Self::moondream2_0_5b()
- .with_batch_size(1)
- .with_model_file("0.5b-coord-decoder.onnx")
- }
-
- pub fn moondream2_0_5b_size_encoder() -> Self {
- Self::moondream2_0_5b()
- .with_batch_size(1)
- .with_model_file("0.5b-size-encoder.onnx")
- }
-
- pub fn moondream2_0_5b_size_decoder() -> Self {
- Self::moondream2_0_5b()
- .with_batch_size(1)
- .with_model_file("0.5b-size-decoder.onnx")
- }
-
- pub fn moondream2_2b_vision_encoder() -> Self {
- Self::moondream2_0_5b_vision_encoder()
- .with_model_scale(crate::Scale::Billion(2.))
- .with_model_file("2b-vision-encoder.onnx")
- }
-
- pub fn moondream2_2b_vision_projection() -> Self {
- Self::moondream2_0_5b_vision_projection()
- .with_model_scale(crate::Scale::Billion(2.))
- .with_model_file("2b-vision-projection.onnx")
- }
-
- pub fn moondream2_2b_text_decoder() -> Self {
- Self::moondream2_0_5b_text_decoder()
- .with_model_scale(crate::Scale::Billion(2.))
- .with_model_file("2b-text-decoder.onnx")
- }
-
- pub fn moondream2_2b_text_encoder() -> Self {
- Self::moondream2_0_5b_text_encoder()
- .with_model_scale(crate::Scale::Billion(2.))
- .with_model_file("2b-text-encoder.onnx")
- }
-
- pub fn moondream2_2b_coord_encoder() -> Self {
- Self::moondream2_0_5b_coord_encoder()
- .with_model_scale(crate::Scale::Billion(2.))
- .with_model_file("2b-coord-encoder.onnx")
- }
-
- pub fn moondream2_2b_coord_decoder() -> Self {
- Self::moondream2_0_5b_coord_decoder()
- .with_model_scale(crate::Scale::Billion(2.))
- .with_model_file("2b-coord-decoder.onnx")
- }
-
- pub fn moondream2_2b_size_encoder() -> Self {
- Self::moondream2_0_5b_size_encoder()
- .with_model_scale(crate::Scale::Billion(2.))
- .with_model_file("2b-size-encoder.onnx")
- }
-
- pub fn moondream2_2b_size_decoder() -> Self {
- Self::moondream2_0_5b_size_decoder()
- .with_model_scale(crate::Scale::Billion(2.))
- .with_model_file("2b-size-decoder.onnx")
+ pub fn moondream2_2b() -> Self {
+ Self::moondream2()
+ .with_scale(crate::Scale::Billion(2.))
+ .with_visual_encoder_file("2b-vision-encoder.onnx")
+ .with_visual_projection_file("2b-vision-projection.onnx")
+ .with_textual_decoder_file("2b-text-decoder.onnx")
+ .with_textual_encoder_file("2b-text-encoder.onnx")
+ .with_coord_encoder_file("2b-coord-encoder.onnx")
+ .with_coord_decoder_file("2b-coord-decoder.onnx")
+ .with_size_encoder_file("2b-size-encoder.onnx")
+ .with_size_decoder_file("2b-size-decoder.onnx")
}
}
diff --git a/src/models/moondream2/impl.rs b/src/models/moondream2/impl.rs
index 55a0942..30bab95 100644
--- a/src/models/moondream2/impl.rs
+++ b/src/models/moondream2/impl.rs
@@ -5,66 +5,57 @@ use ndarray::{s, Array, Array2, Array3, Axis, IxDyn};
use ndarray_npy::ReadNpyExt;
use crate::{
- BaseModelTextual, DType, Engine, Hbb, Hub, Image, Keypoint, LogitsSampler, Options, Processor,
- Scale, Task, Ts, Xs, X, Y,
+ DType, Engine, Hbb, Hub, Image, Keypoint, LogitsSampler, ModelConfig, Processor, Scale, Task,
+ Xs, X, Y,
};
#[derive(Builder, Debug)]
pub struct Moondream2 {
- vision_encoder: VisionEncoder,
- vision_projection: VisionProjection,
- pub text_decoder: BaseModelTextual,
- text_encoder: BaseModelTextual,
- coord_decoder: Option,
- coord_encoder: Option,
- size_decoder: Option,
- size_encoder: Option,
+ vision_encoder: Engine,
+ vision_projection: Engine,
+ text_decoder: Engine,
+ text_encoder: Engine,
+ coord_decoder: Option,
+ coord_encoder: Option,
+ size_decoder: Option,
+ size_encoder: Option,
initial_kv_cache: X, // TODO: use f16
scale: Scale,
dtype: DType,
max_length: usize,
eos_token_id: u32,
max_objects: usize,
+ num_patch: usize,
+ patch_size: usize,
+ processor: Processor,
+ seq_len: usize,
}
impl Moondream2 {
- // TODO
- #[allow(clippy::too_many_arguments)]
- pub fn new(
- options_vision_encoder: Options,
- options_vision_projection: Options,
- options_text_encoder: Options,
- options_text_decoder: Options,
- options_coord_encoder: Option,
- options_coord_decoder: Option,
- options_size_encoder: Option,
- options_size_decoder: Option,
- ) -> Result {
+ pub fn new(config: ModelConfig) -> Result {
let max_length = 2048;
let max_objects = 50;
let eos_token_id = 50256;
- let dtype = options_vision_encoder.model_dtype;
- let scale = options_vision_encoder
- .model_scale
- .clone()
- .unwrap_or(Scale::Billion(0.5));
+ let dtype = config.visual_encoder.dtype;
+ let scale = config.scale.clone().unwrap_or(Scale::Billion(0.5));
let initial_kv_cache: X = KVCache::new(&scale, &dtype)?.0.into();
- let vision_encoder = VisionEncoder::new(options_vision_encoder)?;
- let vision_projection = VisionProjection::new(options_vision_projection)?;
- let text_decoder = BaseModelTextual::new(options_text_decoder)?;
- let text_encoder = BaseModelTextual::new(options_text_encoder)?;
- let coord_decoder = options_coord_decoder
- .map(BaseModelTextual::new)
- .transpose()?;
- let coord_encoder = options_coord_encoder
- .map(BaseModelTextual::new)
- .transpose()?;
- let size_decoder = options_size_decoder
- .map(BaseModelTextual::new)
- .transpose()?;
- let size_encoder = options_size_encoder
- .map(BaseModelTextual::new)
- .transpose()?;
+ let vision_encoder = Engine::try_from_config(&config.visual_encoder)?;
+ let vision_projection = Engine::try_from_config(&config.visual_projection)?;
+ let text_decoder = Engine::try_from_config(&config.textual_decoder)?;
+ let text_encoder = Engine::try_from_config(&config.textual_encoder)?;
+ let coord_decoder = Engine::try_from_config(&config.coord_decoder).ok();
+ let coord_encoder = Engine::try_from_config(&config.coord_encoder).ok();
+ let size_decoder = Engine::try_from_config(&config.size_decoder).ok();
+ let size_encoder = Engine::try_from_config(&config.size_encoder).ok();
+ let (num_patch, patch_size, _ts) = (
+ vision_encoder.batch().opt(),
+ vision_encoder.try_height().unwrap_or(&378.into()).opt(),
+ vision_encoder.ts.clone(),
+ );
+ let seq_len = vision_projection.inputs_minoptmax[0][1].opt();
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(patch_size as _)
+ .with_image_height(patch_size as _);
Ok(Self {
vision_encoder,
@@ -81,12 +72,16 @@ impl Moondream2 {
eos_token_id,
scale,
dtype,
+ num_patch,
+ patch_size,
+ processor,
+ seq_len,
})
}
pub fn encode_image(&mut self, x: &Image) -> Result {
- let patches_emb = self.vision_encoder.encode(x)?.clone().insert_axis(0)?;
- let image_embedding = self.vision_projection.inference(patches_emb.into())?[0].to_owned();
+ let patches_emb = self.encode(x)?.clone().insert_axis(0)?;
+ let image_embedding = self.vision_projection.run(patches_emb.into())?[0].to_owned();
Ok(image_embedding)
}
@@ -119,12 +114,7 @@ impl Moondream2 {
Task::Vqa(query) => {
let input_ids: Vec<_> = [198., 198., 24361., 25.]
.iter()
- .chain(
- &self
- .text_encoder
- .processor()
- .encode_text_ids(query, false)?,
- )
+ .chain(&self.processor.encode_text_ids(query, false)?)
.chain(&[198., 198., 33706., 25.])
.cloned()
.collect();
@@ -139,8 +129,7 @@ impl Moondream2 {
.iter()
.chain(
&self
- .text_encoder
- .processor()
+ .processor
.encode_text_ids(&format!(" {}", object), false)?,
)
.chain(&[628.])
@@ -156,8 +145,7 @@ impl Moondream2 {
.iter()
.chain(
&self
- .text_encoder
- .processor()
+ .processor
.encode_text_ids(&format!(" {}", object), false)?,
)
.chain(&[628.])
@@ -174,10 +162,10 @@ impl Moondream2 {
fn generate_text(&mut self, input_ids: &[f32], kv_cache: Array) -> Result {
let input_ids = X::from(input_ids.to_vec()).insert_axis(0)?;
- let mut input_embeds = self.text_encoder.inference(Xs::from(input_ids))?[0].to_owned();
+ let mut input_embeds = self.text_encoder.run(Xs::from(input_ids))?[0].to_owned();
let logits_sampler = LogitsSampler::new();
let mut token_ids: Vec = Vec::new();
- let mut pos = self.vision_projection.seq_len() + self.initial_kv_cache.shape()[4];
+ let mut pos = self.seq_len + self.initial_kv_cache.shape()[4];
let mut inc = input_embeds.shape()[1];
let mut kv_cache = kv_cache.clone();
@@ -192,7 +180,7 @@ impl Moondream2 {
.into_dyn()
.into(),
]);
- let decoder_outputs = self.text_decoder.inference(input)?;
+ let decoder_outputs = self.text_decoder.run(input)?;
// update
let logits = &decoder_outputs["logits"];
@@ -221,13 +209,10 @@ impl Moondream2 {
// encode
let next_tokens = X::from(vec![token_id as f32]).insert_axis(1)?;
- input_embeds = self.text_encoder.inference(Xs::from(next_tokens))?[0].to_owned();
+ input_embeds = self.text_encoder.run(Xs::from(next_tokens))?[0].to_owned();
}
- let text = self
- .text_encoder
- .processor()
- .decode_tokens(&token_ids, true)?;
+ let text = self.processor.decode_tokens(&token_ids, true)?;
Ok(text)
}
@@ -242,16 +227,16 @@ impl Moondream2 {
let mut y_bboxes: Vec = Vec::new();
let mut y_kpts: Vec> = Vec::new();
let (image_height, image_width) = (
- self.vision_encoder.processor.images_transform_info[0].height_src,
- self.vision_encoder.processor.images_transform_info[0].width_src,
+ self.processor.images_transform_info[0].height_src,
+ self.processor.images_transform_info[0].width_src,
);
- let mut pos = self.vision_projection.seq_len() + self.initial_kv_cache.shape()[4];
+ let mut pos = self.seq_len + self.initial_kv_cache.shape()[4];
let logits_sampler = LogitsSampler::new();
// initial input_embeds
let input_ids = X::from(input_ids.to_vec()).insert_axis(0)?;
- let mut hidden = self.text_encoder.inference(Xs::from(input_ids))?[0].to_owned();
+ let mut hidden = self.text_encoder.run(Xs::from(input_ids))?[0].to_owned();
let mut kv_cache = kv_cache;
// generate
@@ -273,12 +258,7 @@ impl Moondream2 {
// cx
let input: X = hidden.slice(s![0, -1, ..]).into_owned().into_dyn().into();
- let cx = self
- .coord_decoder
- .as_mut()
- .unwrap()
- .inference(Xs::from(input))?[0]
- .clone(); // [1024]
+ let cx = self.coord_decoder.as_mut().unwrap().run(Xs::from(input))?[0].clone(); // [1024]
let ratio = cx.shape()[0] as f32;
let cx = logits_sampler
.decode(cx.as_slice().context("Failed to get slice for `cx`")?)?
@@ -288,7 +268,7 @@ impl Moondream2 {
.coord_encoder
.as_mut()
.unwrap()
- .inference(Xs::from(X::from(vec![cx])))?[0]
+ .run(Xs::from(X::from(vec![cx])))?[0]
.clone()
.insert_axis(0)?
.insert_axis(0)?;
@@ -296,12 +276,7 @@ impl Moondream2 {
// cy
let _logits = self.run_decoder(&mut hidden, &mut kv_cache, &mut pos)?;
let input: X = hidden.slice(s![0, -1, ..]).into_owned().into_dyn().into();
- let cy = self
- .coord_decoder
- .as_mut()
- .unwrap()
- .inference(Xs::from(input))?[0]
- .clone();
+ let cy = self.coord_decoder.as_mut().unwrap().run(Xs::from(input))?[0].clone();
let ratio = cy.shape()[0] as f32;
let cy = logits_sampler
@@ -313,7 +288,7 @@ impl Moondream2 {
.coord_encoder
.as_mut()
.unwrap()
- .inference(Xs::from(X::from(vec![cy])))?[0]
+ .run(Xs::from(X::from(vec![cy])))?[0]
.clone()
.insert_axis(0)?
.insert_axis(0)?;
@@ -334,12 +309,7 @@ impl Moondream2 {
// wh
let _logits = self.run_decoder(&mut hidden, &mut kv_cache, &mut pos)?;
let input: X = hidden.slice(s![0, -1, ..]).into_owned().into_dyn().into();
- let size = self
- .size_decoder
- .as_mut()
- .unwrap()
- .inference(Xs::from(input))?[0]
- .clone(); // [2, 1024]
+ let size = self.size_decoder.as_mut().unwrap().run(Xs::from(input))?[0].clone(); // [2, 1024]
let ratio = size.shape()[1] as f32;
let w = logits_sampler.decode(
@@ -361,7 +331,7 @@ impl Moondream2 {
.size_encoder
.as_mut()
.unwrap()
- .inference(Xs::from(X::from(vec![w, h])))?[0]
+ .run(Xs::from(X::from(vec![w, h])))?[0]
.clone()
.insert_axis(0)?
.insert_axis(0)?; // [1024]
@@ -392,7 +362,7 @@ impl Moondream2 {
}
fn prepare_kv_cache(&mut self, image_embedding: &X) -> Result> {
- let kv_cache_new = self.text_decoder.inference(Xs::from(vec![
+ let kv_cache_new = self.text_decoder.run(Xs::from(vec![
image_embedding.clone(),
self.initial_kv_cache.clone(),
]))?["new_kv_cache"]
@@ -421,7 +391,7 @@ impl Moondream2 {
kv_cache: &mut Array,
pos: &mut usize,
) -> Result {
- let decoder_outputs = self.text_decoder.inference(Xs::from(vec![
+ let decoder_outputs = self.text_decoder.run(Xs::from(vec![
input_embeds.clone(),
kv_cache
.slice(s![.., .., .., .., ..*pos, ..])
@@ -442,38 +412,6 @@ impl Moondream2 {
Ok(decoder_outputs["logits"].to_owned())
}
-}
-
-#[derive(Debug, Builder)]
-pub struct VisionEncoder {
- engine: Engine,
- num_patch: usize,
- patch_size: usize,
- processor: Processor,
- ts: Ts,
-}
-
-impl VisionEncoder {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
- let (num_patch, patch_size, ts) = (
- engine.batch().opt(),
- engine.try_height().unwrap_or(&378.into()).opt(),
- engine.ts.clone(),
- );
- let processor = options
- .to_processor()?
- .with_image_width(patch_size as _)
- .with_image_height(patch_size as _);
-
- Ok(Self {
- engine,
- patch_size,
- num_patch,
- processor,
- ts,
- })
- }
fn create_patches(image: &Image, image_patch_size: usize) -> (Vec, (u32, u32)) {
let mut patches = vec![image.clone()];
@@ -515,10 +453,6 @@ impl VisionEncoder {
(patches, selected_template)
}
- pub fn inference(&mut self, xs: Xs) -> Result {
- self.engine.run(xs)
- }
-
pub fn encode(&mut self, x: &Image) -> Result {
let (patches, selected_template) = Self::create_patches(x, self.patch_size);
let patches = self.processor.process_images(&patches)?;
@@ -526,7 +460,7 @@ impl VisionEncoder {
(selected_template.0 as usize),
(selected_template.1 as usize),
);
- let patch_emb = self.inference(patches.clone().into())?[0].clone();
+ let patch_emb = self.vision_encoder.run(patches.clone().into())?[0].clone();
let patch_emb = patch_emb.clone().0.into_dimensionality::()?;
let patch_emb = Self::process_patch_emb(patch_emb, template)?;
let patch_emb = X::from(patch_emb.into_dyn()); // TODO .insert_axis(x),
@@ -608,30 +542,6 @@ impl VisionEncoder {
}
}
-#[derive(Debug, Builder)]
-pub struct VisionProjection {
- engine: Engine,
- seq_len: usize,
- ts: Ts,
-}
-
-impl VisionProjection {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
- let (seq_len, ts) = (engine.inputs_minoptmax[0][1].opt(), engine.ts.clone());
-
- Ok(Self {
- engine,
- seq_len,
- ts,
- })
- }
-
- pub fn inference(&mut self, xs: Xs) -> Result {
- self.engine.run(xs)
- }
-}
-
#[derive(Builder, Debug)]
struct KVCache(pub Array);
diff --git a/src/models/owl/config.rs b/src/models/owl/config.rs
index cc17da9..2520c1b 100644
--- a/src/models/owl/config.rs
+++ b/src/models/owl/config.rs
@@ -1,11 +1,10 @@
/// Model configuration for `OWLv2`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn owlv2() -> Self {
Self::default()
- .with_model_name("owlv2")
- .with_model_kind(crate::Kind::VisionLanguage)
+ .with_name("owlv2")
// 1st & 3rd: text
- .with_model_ixx(0, 0, (1, 1, 1).into()) // TODO
+ .with_model_ixx(0, 0, (1, 1, 1).into())
.with_model_ixx(0, 1, 1.into())
.with_model_ixx(2, 0, (1, 1, 1).into())
.with_model_ixx(2, 1, 1.into())
@@ -21,6 +20,7 @@ impl crate::Options {
.with_normalize(true)
.with_class_confs(&[0.1])
.with_model_num_dry_run(0)
+ .with_tokenizer_file("owlv2/tokenizer.json")
}
pub fn owlv2_base() -> Self {
diff --git a/src/models/owl/impl.rs b/src/models/owl/impl.rs
index ca0ef67..2295c77 100644
--- a/src/models/owl/impl.rs
+++ b/src/models/owl/impl.rs
@@ -3,7 +3,7 @@ use anyhow::Result;
use ndarray::{s, Axis};
use rayon::prelude::*;
-use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, X, Y};
+use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y};
#[derive(Debug, Builder)]
pub struct OWLv2 {
@@ -22,8 +22,8 @@ pub struct OWLv2 {
}
impl OWLv2 {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let (batch, height, width, ts) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&960.into()).opt(),
@@ -31,11 +31,7 @@ impl OWLv2 {
engine.ts.clone(),
);
let spec = engine.spec().to_owned();
- let processor = options
- .to_processor()?
- .with_image_width(width as _)
- .with_image_height(height as _);
- let names: Vec = options
+ let names: Vec = config
.class_names()
.expect("No class names specified.")
.iter()
@@ -44,7 +40,10 @@ impl OWLv2 {
let names_with_prompt: Vec =
names.iter().map(|x| format!("a photo of {}", x)).collect();
let n = names.len();
- let confs = DynConf::new(options.class_confs(), n);
+ let confs = DynConf::new(config.class_confs(), n);
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
let input_ids: Vec = processor
.encode_texts_ids(
&names_with_prompt
diff --git a/src/models/picodet/config.rs b/src/models/picodet/config.rs
index b3988a5..3ca083c 100644
--- a/src/models/picodet/config.rs
+++ b/src/models/picodet/config.rs
@@ -4,11 +4,11 @@ use crate::{
};
/// Model configuration for `PicoDet`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn picodet() -> Self {
Self::default()
- .with_model_name("picodet")
- .with_batch_size(1) // TODO: ONNX model's batch size seems always = 1
+ .with_name("picodet")
+ .with_batch_size_all(1) // TODO: ONNX model's batch size seems always = 1
.with_model_ixx(0, 2, 640.into())
.with_model_ixx(0, 3, 640.into())
.with_model_ixx(1, 0, (1, 1, 8).into())
diff --git a/src/models/picodet/impl.rs b/src/models/picodet/impl.rs
index 6bda13a..0a731e3 100644
--- a/src/models/picodet/impl.rs
+++ b/src/models/picodet/impl.rs
@@ -3,7 +3,7 @@ use anyhow::Result;
use ndarray::Axis;
use rayon::prelude::*;
-use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, X, Y};
+use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y};
#[derive(Debug, Builder)]
pub struct PicoDet {
@@ -19,8 +19,8 @@ pub struct PicoDet {
}
impl PicoDet {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let (batch, height, width, ts) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&640.into()).opt(),
@@ -28,15 +28,14 @@ impl PicoDet {
engine.ts.clone(),
);
let spec = engine.spec().to_owned();
- let processor = options
- .to_processor()?
- .with_image_width(width as _)
- .with_image_height(height as _);
- let names = options
+ let names = config
.class_names()
.expect("No class names are specified.")
.to_vec();
- let confs = DynConf::new(options.class_confs(), names.len());
+ let confs = DynConf::new(config.class_confs(), names.len());
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
Ok(Self {
engine,
diff --git a/src/models/pipeline/basemodel.rs b/src/models/pipeline/basemodel.rs
index 062c84f..3a791f0 100644
--- a/src/models/pipeline/basemodel.rs
+++ b/src/models/pipeline/basemodel.rs
@@ -2,8 +2,7 @@ use aksr::Builder;
use anyhow::Result;
use crate::{
- elapsed, DType, Device, Engine, Image, Kind, Options, Processor, Scale, Task, Ts, Version, Xs,
- X,
+ elapsed, DType, Device, Engine, Image, ModelConfig, Processor, Scale, Task, Ts, Version, Xs, X,
};
#[derive(Debug, Builder)]
@@ -20,7 +19,6 @@ pub struct BaseModelVisual {
dtype: DType,
task: Option,
scale: Option,
- kind: Option,
version: Option,
}
@@ -29,8 +27,8 @@ impl BaseModelVisual {
self.ts.summary();
}
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let err_msg = "You need to specify the image height and image width for visual model.";
let (batch, height, width, ts, spec) = (
engine.batch().opt(),
@@ -39,18 +37,16 @@ impl BaseModelVisual {
engine.ts.clone(),
engine.spec().to_owned(),
);
- let processor = options
- .to_processor()?
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
- let device = options.model_device;
- let task = options.model_task;
- let scale = options.model_scale;
- let dtype = options.model_dtype;
- let kind = options.model_kind;
- let name = options.model_name;
- let version = options.model_version;
+ let device = config.model.device;
+ let task = config.task;
+ let scale = config.scale;
+ let dtype = config.model.dtype;
+ let name = config.name;
+ let version = config.version;
Ok(Self {
engine,
@@ -63,7 +59,6 @@ impl BaseModelVisual {
dtype,
task,
scale,
- kind,
device,
version,
name,
@@ -101,7 +96,6 @@ pub struct BaseModelTextual {
dtype: DType,
task: Option,
scale: Option,
- kind: Option,
version: Option,
}
@@ -110,21 +104,20 @@ impl BaseModelTextual {
self.ts.summary();
}
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let (batch, ts, spec) = (
engine.batch().opt(),
engine.ts.clone(),
engine.spec().to_owned(),
);
- let processor = options.to_processor()?;
- let device = options.model_device;
- let task = options.model_task;
- let scale = options.model_scale;
- let dtype = options.model_dtype;
- let kind = options.model_kind;
- let name = options.model_name;
- let version = options.model_version;
+ let processor = Processor::try_from_config(&config.processor)?;
+ let device = config.model.device;
+ let dtype = config.model.dtype;
+ let task = config.task;
+ let scale = config.scale;
+ let name = config.name;
+ let version = config.version;
Ok(Self {
engine,
@@ -135,7 +128,6 @@ impl BaseModelTextual {
dtype,
task,
scale,
- kind,
device,
version,
name,
diff --git a/src/models/pipeline/image_classifier.rs b/src/models/pipeline/image_classifier.rs
index 0dafd43..e7a9b12 100644
--- a/src/models/pipeline/image_classifier.rs
+++ b/src/models/pipeline/image_classifier.rs
@@ -3,7 +3,7 @@ use anyhow::Result;
use ndarray::Axis;
use rayon::prelude::*;
-use crate::{elapsed, DynConf, Engine, Image, Options, Prob, Processor, Ts, Xs, Y};
+use crate::{elapsed, DynConf, Engine, Image, ModelConfig, Prob, Processor, Ts, Xs, Y};
#[derive(Debug, Builder)]
pub struct ImageClassifier {
@@ -20,11 +20,12 @@ pub struct ImageClassifier {
spec: String,
}
-impl TryFrom for ImageClassifier {
+impl TryFrom for ImageClassifier {
type Error = anyhow::Error;
- fn try_from(options: Options) -> Result {
- let engine = options.to_engine()?;
+ fn try_from(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
+
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
@@ -32,11 +33,8 @@ impl TryFrom for ImageClassifier {
engine.try_width().unwrap_or(&224.into()).opt(),
engine.ts().clone(),
);
- let processor = options
- .to_processor()?
- .with_image_width(width as _)
- .with_image_height(height as _);
- let (nc, names) = match (options.nc(), options.class_names()) {
+
+ let (nc, names) = match (config.nc(), config.class_names()) {
(Some(nc), Some(names)) => {
if nc != names.len() {
anyhow::bail!(
@@ -56,8 +54,11 @@ impl TryFrom for ImageClassifier {
anyhow::bail!("Neither class names nor class numbers were specified.");
}
};
- let confs = DynConf::new(options.class_confs(), nc);
- let apply_softmax = options.apply_softmax.unwrap_or_default();
+ let confs = DynConf::new(config.class_confs(), nc);
+ let apply_softmax = config.apply_softmax.unwrap_or_default();
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
Ok(Self {
engine,
diff --git a/src/models/rfdetr/config.rs b/src/models/rfdetr/config.rs
index 53a861a..85f1220 100644
--- a/src/models/rfdetr/config.rs
+++ b/src/models/rfdetr/config.rs
@@ -1,18 +1,17 @@
use crate::NAMES_COCO_91;
/// Model configuration for `RT-DETR`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn rfdetr() -> Self {
Self::default()
- .with_model_name("rfdetr")
- .with_batch_size(1)
+ .with_name("rfdetr")
+ .with_model_ixx(0, 0, 1.into())
+ .with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 560.into())
.with_model_ixx(0, 3, 560.into())
.with_resize_mode(crate::ResizeMode::FitAdaptive)
- .with_normalize(true)
.with_image_mean(&[0.485, 0.456, 0.406])
.with_image_std(&[0.229, 0.224, 0.225])
- .with_class_confs(&[0.25])
.with_class_names(&NAMES_COCO_91)
}
diff --git a/src/models/rfdetr/impl.rs b/src/models/rfdetr/impl.rs
index dcac505..6d15e52 100644
--- a/src/models/rfdetr/impl.rs
+++ b/src/models/rfdetr/impl.rs
@@ -3,7 +3,7 @@ use anyhow::Result;
use ndarray::{s, Axis};
use rayon::prelude::*;
-use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, Y};
+use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, Y};
#[derive(Debug, Builder)]
pub struct RFDETR {
@@ -19,8 +19,8 @@ pub struct RFDETR {
}
impl RFDETR {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let (batch, height, width, ts) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&560.into()).opt(),
@@ -28,16 +28,16 @@ impl RFDETR {
engine.ts.clone(),
);
let spec = engine.spec().to_owned();
- let processor = options
- .to_processor()?
- .with_image_width(width as _)
- .with_image_height(height as _);
- let names = options
+ let names: Vec = config
.class_names()
.expect("No class names specified.")
- .to_vec();
- let confs = DynConf::new(options.class_confs(), names.len());
-
+ .iter()
+ .map(|x| x.to_string())
+ .collect();
+ let confs = DynConf::new(config.class_confs(), names.len());
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
Ok(Self {
engine,
height,
diff --git a/src/models/rmbg/config.rs b/src/models/rmbg/config.rs
index 65ae2bb..2bb963d 100644
--- a/src/models/rmbg/config.rs
+++ b/src/models/rmbg/config.rs
@@ -1,9 +1,10 @@
/// Model configuration for `RMBG`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn rmbg() -> Self {
Self::default()
- .with_model_name("rmbg")
+ .with_name("rmbg")
.with_model_ixx(0, 0, 1.into())
+ .with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 1024.into())
.with_model_ixx(0, 3, 1024.into())
}
diff --git a/src/models/rmbg/impl.rs b/src/models/rmbg/impl.rs
index f886fa6..ea4d9c6 100644
--- a/src/models/rmbg/impl.rs
+++ b/src/models/rmbg/impl.rs
@@ -1,7 +1,7 @@
use aksr::Builder;
use anyhow::Result;
-use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y};
+use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Processor, Ts, Xs, Y};
#[derive(Builder, Debug)]
pub struct RMBG {
@@ -15,8 +15,8 @@ pub struct RMBG {
}
impl RMBG {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
@@ -24,8 +24,7 @@ impl RMBG {
engine.try_width().unwrap_or(&1024.into()).opt(),
engine.ts().clone(),
);
- let processor = options
- .to_processor()?
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
@@ -63,7 +62,6 @@ impl RMBG {
fn postprocess(&mut self, xs: Xs) -> Result> {
let mut ys: Vec = Vec::new();
for (idx, luma) in xs[0].axis_iter(ndarray::Axis(0)).enumerate() {
- // image size
let (h1, w1) = (
self.processor.images_transform_info[idx].height_src,
self.processor.images_transform_info[idx].width_src,
diff --git a/src/models/rtdetr/config.rs b/src/models/rtdetr/config.rs
index 2b2ebe9..56f8da4 100644
--- a/src/models/rtdetr/config.rs
+++ b/src/models/rtdetr/config.rs
@@ -1,15 +1,15 @@
use crate::NAMES_COCO_80;
/// Model configuration for `RT-DETR`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn rtdetr() -> Self {
Self::default()
- .with_model_name("rtdetr")
- .with_batch_size(1)
+ .with_name("rtdetr")
+ .with_model_ixx(0, 0, 1.into())
+ .with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 640.into())
.with_model_ixx(0, 3, 640.into())
.with_resize_mode(crate::ResizeMode::FitAdaptive)
- .with_normalize(true)
.with_class_confs(&[0.5])
.with_class_names(&NAMES_COCO_80)
}
diff --git a/src/models/rtdetr/impl.rs b/src/models/rtdetr/impl.rs
index ebef66b..53db03c 100644
--- a/src/models/rtdetr/impl.rs
+++ b/src/models/rtdetr/impl.rs
@@ -3,7 +3,7 @@ use anyhow::Result;
use ndarray::{s, Axis};
use rayon::prelude::*;
-use crate::{elapsed, DynConf, Engine, Hbb, Image, Options, Processor, Ts, Xs, X, Y};
+use crate::{elapsed, DynConf, Engine, Hbb, Image, ModelConfig, Processor, Ts, Xs, X, Y};
#[derive(Debug, Builder)]
pub struct RTDETR {
@@ -19,8 +19,8 @@ pub struct RTDETR {
}
impl RTDETR {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let (batch, height, width, ts) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&640.into()).opt(),
@@ -28,15 +28,16 @@ impl RTDETR {
engine.ts.clone(),
);
let spec = engine.spec().to_owned();
- let processor = options
- .to_processor()?
- .with_image_width(width as _)
- .with_image_height(height as _);
- let names = options
+ let names: Vec = config
.class_names()
.expect("No class names specified.")
- .to_vec();
- let confs = DynConf::new(options.class_confs(), names.len());
+ .iter()
+ .map(|x| x.to_string())
+ .collect();
+ let confs = DynConf::new(config.class_confs(), names.len());
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
Ok(Self {
engine,
@@ -87,7 +88,6 @@ impl RTDETR {
.enumerate()
.filter_map(|(idx, ((labels, boxes), scores))| {
let ratio = self.processor.images_transform_info[idx].height_scale;
-
let mut y_bboxes = Vec::new();
for (i, &score) in scores.iter().enumerate() {
let class_id = labels[i] as usize;
@@ -102,7 +102,6 @@ impl RTDETR {
xyxy[2] / ratio,
xyxy[3] / ratio,
);
-
y_bboxes.push(
Hbb::default()
.with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2)
diff --git a/src/models/rtmo/config.rs b/src/models/rtmo/config.rs
index d223269..43bcb24 100644
--- a/src/models/rtmo/config.rs
+++ b/src/models/rtmo/config.rs
@@ -1,9 +1,10 @@
/// Model configuration for `RTMO`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn rtmo() -> Self {
Self::default()
- .with_model_name("rtmo")
+ .with_name("rtmo")
.with_model_ixx(0, 0, 1.into())
+ .with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 640.into())
.with_model_ixx(0, 3, 640.into())
.with_resize_mode(crate::ResizeMode::FitAdaptive)
diff --git a/src/models/rtmo/impl.rs b/src/models/rtmo/impl.rs
index 1edb944..0b42d30 100644
--- a/src/models/rtmo/impl.rs
+++ b/src/models/rtmo/impl.rs
@@ -2,7 +2,7 @@ use aksr::Builder;
use anyhow::Result;
use ndarray::Axis;
-use crate::{elapsed, DynConf, Engine, Hbb, Image, Keypoint, Options, Processor, Ts, Xs, Y};
+use crate::{elapsed, DynConf, Engine, Hbb, Image, Keypoint, ModelConfig, Processor, Ts, Xs, Y};
#[derive(Builder, Debug)]
pub struct RTMO {
@@ -18,8 +18,8 @@ pub struct RTMO {
}
impl RTMO {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
@@ -27,15 +27,14 @@ impl RTMO {
engine.try_width().unwrap_or(&512.into()).opt(),
engine.ts().clone(),
);
- let processor = options
- .to_processor()?
+
+ let nk = config.nk().unwrap_or(17);
+ let confs = DynConf::new(config.class_confs(), 1);
+ let kconfs = DynConf::new(config.keypoint_confs(), nk);
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
- let nk = options.nk().unwrap_or(17);
- let confs = DynConf::new(options.class_confs(), 1);
- let kconfs = DynConf::new(options.keypoint_confs(), nk);
-
Ok(Self {
engine,
height,
diff --git a/src/models/sam/config.rs b/src/models/sam/config.rs
index 0e9ce58..f46e86c 100644
--- a/src/models/sam/config.rs
+++ b/src/models/sam/config.rs
@@ -1,100 +1,73 @@
-use crate::{models::SamKind, Options};
+use crate::{models::SamKind, ModelConfig};
/// Model configuration for `Segment Anything Model`
-impl Options {
+impl ModelConfig {
pub fn sam() -> Self {
Self::default()
- .with_model_name("sam")
- .with_model_ixx(0, 0, 1.into())
- }
-
- pub fn sam_encoder() -> Self {
- Self::sam()
- .with_model_ixx(0, 2, 1024.into())
- .with_model_ixx(0, 3, 1024.into())
+ .with_name("sam")
+ .with_encoder_ixx(0, 0, 1.into())
+ .with_encoder_ixx(0, 1, 3.into())
+ .with_encoder_ixx(0, 2, 1024.into())
+ .with_encoder_ixx(0, 3, 1024.into())
.with_resize_mode(crate::ResizeMode::FitAdaptive)
.with_resize_filter("Bilinear")
.with_image_mean(&[123.5, 116.5, 103.5])
.with_image_std(&[58.5, 57.0, 57.5])
.with_normalize(false)
.with_sam_kind(SamKind::Sam)
- .with_low_res_mask(false)
+ .with_sam_low_res_mask(false)
.with_find_contours(true)
}
- pub fn sam_decoder() -> Self {
+ pub fn sam_v1_base() -> Self {
Self::sam()
+ .with_encoder_file("sam-vit-b-encoder.onnx")
+ .with_decoder_file("sam-vit-b-decoder.onnx")
}
- pub fn sam_v1_base_encoder() -> Self {
- Self::sam_encoder().with_model_file("sam-vit-b-encoder.onnx")
+ // pub fn sam_v1_base_singlemask_decoder() -> Self {
+ // Self::sam().with_decoder_file("sam-vit-b-decoder-singlemask.onnx")
+ // }
+
+ pub fn sam2_tiny() -> Self {
+ Self::sam()
+ .with_encoder_file("sam2-hiera-tiny-encoder.onnx")
+ .with_sam_kind(SamKind::Sam2)
+ .with_decoder_file("sam2-hiera-tiny-decoder.onnx")
}
- pub fn sam_v1_base_decoder() -> Self {
- Self::sam_decoder().with_model_file("sam-vit-b-decoder.onnx")
- }
-
- pub fn sam_v1_base_singlemask_decoder() -> Self {
- Self::sam_decoder().with_model_file("sam-vit-b-decoder-singlemask.onnx")
- }
-
- pub fn sam2_tiny_encoder() -> Self {
- Self::sam_encoder()
- .with_model_file("sam2-hiera-tiny-encoder.onnx")
+ pub fn sam2_small() -> Self {
+ Self::sam()
+ .with_encoder_file("sam2-hiera-small-encoder.onnx")
+ .with_decoder_file("sam2-hiera-small-decoder.onnx")
.with_sam_kind(SamKind::Sam2)
}
- pub fn sam2_tiny_decoder() -> Self {
- Self::sam_decoder().with_model_file("sam2-hiera-tiny-decoder.onnx")
- }
-
- pub fn sam2_small_encoder() -> Self {
- Self::sam_encoder()
- .with_model_file("sam2-hiera-small-encoder.onnx")
+ pub fn sam2_base_plus() -> Self {
+ Self::sam()
+ .with_encoder_file("sam2-hiera-base-plus-encoder.onnx")
+ .with_decoder_file("sam2-hiera-base-plus-decoder.onnx")
.with_sam_kind(SamKind::Sam2)
}
- pub fn sam2_small_decoder() -> Self {
- Self::sam_decoder().with_model_file("sam2-hiera-small-decoder.onnx")
- }
-
- pub fn sam2_base_plus_encoder() -> Self {
- Self::sam_encoder()
- .with_model_file("sam2-hiera-base-plus-encoder.onnx")
- .with_sam_kind(SamKind::Sam2)
- }
-
- pub fn sam2_base_plus_decoder() -> Self {
- Self::sam_decoder().with_model_file("sam2-hiera-base-plus-decoder.onnx")
- }
-
- pub fn mobile_sam_tiny_encoder() -> Self {
- Self::sam_encoder()
- .with_model_file("mobile-sam-vit-t-encoder.onnx")
+ pub fn mobile_sam_tiny() -> Self {
+ Self::sam()
+ .with_encoder_file("mobile-sam-vit-t-encoder.onnx")
.with_sam_kind(SamKind::MobileSam)
+ .with_decoder_file("mobile-sam-vit-t-decoder.onnx")
}
- pub fn mobile_sam_tiny_decoder() -> Self {
- Self::sam_decoder().with_model_file("mobile-sam-vit-t-decoder.onnx")
- }
-
- pub fn sam_hq_tiny_encoder() -> Self {
- Self::sam_encoder()
- .with_model_file("sam-hq-vit-t-encoder.onnx")
+ pub fn sam_hq_tiny() -> Self {
+ Self::sam()
+ .with_encoder_file("sam-hq-vit-t-encoder.onnx")
.with_sam_kind(SamKind::SamHq)
+ .with_decoder_file("sam-hq-vit-t-decoder.onnx")
}
- pub fn sam_hq_tiny_decoder() -> Self {
- Self::sam_decoder().with_model_file("sam-hq-vit-t-decoder.onnx")
- }
-
- pub fn edge_sam_3x_encoder() -> Self {
- Self::sam_encoder()
- .with_model_file("edge-sam-3x-encoder.onnx")
+ pub fn edge_sam_3x() -> Self {
+ Self::sam()
+ .with_encoder_file("edge-sam-3x-encoder.onnx")
+ .with_decoder_file("edge-sam-3x-decoder.onnx")
.with_sam_kind(SamKind::EdgeSam)
}
-
- pub fn edge_sam_3x_decoder() -> Self {
- Self::sam_decoder().with_model_file("edge-sam-3x-decoder.onnx")
- }
}
diff --git a/src/models/sam/impl.rs b/src/models/sam/impl.rs
index 0e02ca5..fea747d 100644
--- a/src/models/sam/impl.rs
+++ b/src/models/sam/impl.rs
@@ -4,8 +4,8 @@ use ndarray::{s, Axis};
use rand::prelude::*;
use crate::{
- elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, SamPrompt, Ts, Xs, X,
- Y,
+ elapsed, DynConf, Engine, Image, Mask, ModelConfig, Ops, Polygon, Processor, SamPrompt, Ts, Xs,
+ X, Y,
};
#[derive(Debug, Clone)]
@@ -49,9 +49,10 @@ pub struct SAM {
}
impl SAM {
- pub fn new(options_encoder: Options, options_decoder: Options) -> Result {
- let encoder = options_encoder.to_engine()?;
- let decoder = options_decoder.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let encoder = Engine::try_from_config(&config.encoder)?;
+ let decoder = Engine::try_from_config(&config.decoder)?;
+
let (batch, height, width) = (
encoder.batch().opt(),
encoder.try_height().unwrap_or(&1024.into()).opt(),
@@ -60,24 +61,23 @@ impl SAM {
let ts = Ts::merge(&[encoder.ts(), decoder.ts()]);
let spec = encoder.spec().to_owned();
- let processor = options_encoder
- .to_processor()?
- .with_image_width(width as _)
- .with_image_height(height as _);
-
- let conf = DynConf::new(options_encoder.class_confs(), 1);
- let find_contours = options_encoder.find_contours;
- let kind = match options_encoder.sam_kind {
+ let conf = DynConf::new(config.class_confs(), 1);
+ let find_contours = config.find_contours;
+ let kind = match config.sam_kind {
Some(x) => x,
None => anyhow::bail!("Error: no clear `SamKind` specified."),
};
let use_low_res_mask = match kind {
SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
- options_encoder.low_res_mask.unwrap_or(false)
+ config.sam_low_res_mask.unwrap_or(false)
}
SamKind::EdgeSam | SamKind::Sam2 => true,
};
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
+
Ok(Self {
encoder,
decoder,
diff --git a/src/models/sam2/config.rs b/src/models/sam2/config.rs
index db9df28..f58f7a7 100644
--- a/src/models/sam2/config.rs
+++ b/src/models/sam2/config.rs
@@ -1,50 +1,28 @@
-use crate::Options;
+use crate::ModelConfig;
/// Model configuration for `SAM2.1`
-impl Options {
- pub fn sam2_encoder() -> Self {
+impl ModelConfig {
+ pub fn sam2_1_tiny() -> Self {
Self::sam()
- .with_model_ixx(0, 2, 1024.into())
- .with_model_ixx(0, 3, 1024.into())
- .with_resize_mode(crate::ResizeMode::FitAdaptive)
- .with_resize_filter("Bilinear")
- .with_image_mean(&[0.485, 0.456, 0.406])
- .with_image_std(&[0.229, 0.224, 0.225])
+ .with_encoder_file("sam2.1-hiera-tiny-encoder.onnx")
+ .with_decoder_file("sam2.1-hiera-tiny-decoder.onnx")
}
- pub fn sam2_decoder() -> Self {
+ pub fn sam2_1_small() -> Self {
Self::sam()
+ .with_encoder_file("sam2.1-hiera-small-encoder.onnx")
+ .with_decoder_file("sam2.1-hiera-small-decoder.onnx")
}
- pub fn sam2_1_tiny_encoder() -> Self {
- Self::sam2_encoder().with_model_file("sam2.1-hiera-tiny-encoder.onnx")
+ pub fn sam2_1_base_plus() -> Self {
+ Self::sam()
+ .with_encoder_file("sam2.1-hiera-base-plus-encoder.onnx")
+ .with_decoder_file("sam2.1-hiera-base-plus-decoder.onnx")
}
- pub fn sam2_1_tiny_decoder() -> Self {
- Self::sam2_decoder().with_model_file("sam2.1-hiera-tiny-decoder.onnx")
- }
-
- pub fn sam2_1_small_encoder() -> Self {
- Self::sam2_encoder().with_model_file("sam2.1-hiera-small-encoder.onnx")
- }
-
- pub fn sam2_1_small_decoder() -> Self {
- Self::sam2_decoder().with_model_file("sam2.1-hiera-small-decoder.onnx")
- }
-
- pub fn sam2_1_base_plus_encoder() -> Self {
- Self::sam2_encoder().with_model_file("sam2.1-hiera-base-plus-encoder.onnx")
- }
-
- pub fn sam2_1_base_plus_decoder() -> Self {
- Self::sam2_decoder().with_model_file("sam2.1-hiera-base-plus-decoder.onnx")
- }
-
- pub fn sam2_1_large_encoder() -> Self {
- Self::sam2_encoder().with_model_file("sam2.1-hiera-large-encoder.onnx")
- }
-
- pub fn sam2_1_large_decoder() -> Self {
- Self::sam2_decoder().with_model_file("sam2.1-hiera-large-decoder.onnx")
+ pub fn sam2_1_large() -> Self {
+ Self::sam()
+ .with_encoder_file("sam2.1-hiera-large-encoder.onnx")
+ .with_decoder_file("sam2.1-hiera-large-decoder.onnx")
}
}
diff --git a/src/models/sam2/impl.rs b/src/models/sam2/impl.rs
index cb32dce..9f576dc 100644
--- a/src/models/sam2/impl.rs
+++ b/src/models/sam2/impl.rs
@@ -3,7 +3,7 @@ use anyhow::Result;
use ndarray::{s, Axis};
use crate::{
- elapsed, DynConf, Engine, Image, Mask, Ops, Options, Processor, SamPrompt, Ts, Xs, X, Y,
+ elapsed, DynConf, Engine, Image, Mask, ModelConfig, Ops, Processor, SamPrompt, Ts, Xs, X, Y,
};
#[derive(Builder, Debug)]
@@ -20,9 +20,9 @@ pub struct SAM2 {
}
impl SAM2 {
- pub fn new(options_encoder: Options, options_decoder: Options) -> Result {
- let encoder = options_encoder.to_engine()?;
- let decoder = options_decoder.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let encoder = Engine::try_from_config(&config.encoder)?;
+ let decoder = Engine::try_from_config(&config.decoder)?;
let (batch, height, width) = (
encoder.batch().opt(),
encoder.try_height().unwrap_or(&1024.into()).opt(),
@@ -30,11 +30,11 @@ impl SAM2 {
);
let ts = Ts::merge(&[encoder.ts(), decoder.ts()]);
let spec = encoder.spec().to_owned();
- let processor = options_encoder
- .to_processor()?
+
+ let conf = DynConf::new(config.class_confs(), 1);
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
- let conf = DynConf::new(options_encoder.class_confs(), 1);
Ok(Self {
encoder,
diff --git a/src/models/sapiens/config.rs b/src/models/sapiens/config.rs
index 053cf07..a51fbc7 100644
--- a/src/models/sapiens/config.rs
+++ b/src/models/sapiens/config.rs
@@ -1,15 +1,14 @@
use crate::NAMES_BODY_PARTS_28;
/// Model configuration for `Sapiens`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn sapiens() -> Self {
Self::default()
- .with_model_name("sapiens")
+ .with_name("sapiens")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 2, 1024.into())
.with_model_ixx(0, 3, 768.into())
.with_resize_mode(crate::ResizeMode::FitExact)
- .with_resize_filter("Bilinear")
.with_image_mean(&[123.5, 116.5, 103.5])
.with_image_std(&[58.5, 57.0, 57.5])
.with_normalize(false)
@@ -17,31 +16,11 @@ impl crate::Options {
pub fn sapiens_body_part_segmentation() -> Self {
Self::sapiens()
- .with_model_task(crate::Task::InstanceSegmentation)
+ .with_task(crate::Task::InstanceSegmentation)
.with_class_names(&NAMES_BODY_PARTS_28)
}
pub fn sapiens_seg_0_3b() -> Self {
Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b.onnx")
}
-
- // pub fn sapiens_seg_0_3b_uint8() -> Self {
- // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-uint8.onnx")
- // }
-
- // pub fn sapiens_seg_0_3b_fp16() -> Self {
- // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-fp16.onnx")
- // }
-
- // pub fn sapiens_seg_0_3b_bnb4() -> Self {
- // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-bnb4.onnx")
- // }
-
- // pub fn sapiens_seg_0_3b_q4f16() -> Self {
- // Self::sapiens_body_part_segmentation().with_model_file("seg-0.3b-q4f16.onnx")
- // }
-
- // pub fn sapiens_seg_0_6b_fp16() -> Self {
- // Self::sapiens_body_part_segmentation().with_model_file("seg-0.6b-fp16.onnx")
- // }
}
diff --git a/src/models/sapiens/impl.rs b/src/models/sapiens/impl.rs
index 65ea6fd..a192bd5 100644
--- a/src/models/sapiens/impl.rs
+++ b/src/models/sapiens/impl.rs
@@ -2,7 +2,7 @@ use aksr::Builder;
use anyhow::Result;
use ndarray::{s, Array2, Axis};
-use crate::{elapsed, Engine, Image, Mask, Ops, Options, Polygon, Processor, Task, Ts, Xs, Y};
+use crate::{elapsed, Engine, Image, Mask, ModelConfig, Ops, Polygon, Processor, Task, Ts, Xs, Y};
#[derive(Builder, Debug)]
pub struct Sapiens {
@@ -18,8 +18,8 @@ pub struct Sapiens {
}
impl Sapiens {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
@@ -27,12 +27,12 @@ impl Sapiens {
engine.try_width().unwrap_or(&768.into()).opt(),
engine.ts().clone(),
);
- let processor = options
- .to_processor()?
+
+ let task = config.task.expect("No sapiens task specified.");
+ let names_body = config.class_names;
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
- let task = options.model_task.expect("No sapiens task specified.");
- let names_body = options.class_names;
Ok(Self {
engine,
diff --git a/src/models/slanet/config.rs b/src/models/slanet/config.rs
index f29b311..d045fdc 100644
--- a/src/models/slanet/config.rs
+++ b/src/models/slanet/config.rs
@@ -1,14 +1,14 @@
/// Model configuration for `SLANet`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn slanet() -> Self {
Self::default()
- .with_model_name("slanet")
+ .with_name("slanet")
.with_model_ixx(0, 0, (1, 1, 8).into())
+ .with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, (320, 488, 488).into())
.with_model_ixx(0, 3, (320, 488, 488).into())
.with_image_mean(&[0.485, 0.456, 0.406])
.with_image_std(&[0.229, 0.224, 0.225])
- .with_normalize(true)
.with_resize_mode(crate::ResizeMode::FitAdaptive)
.with_padding_value(0)
.with_unsigned(true)
@@ -17,6 +17,6 @@ impl crate::Options {
pub fn slanet_lcnet_v2_mobile_ch() -> Self {
Self::slanet()
.with_model_file("v2-mobile-ch.onnx")
- .with_vocab_txt("vocab-sla-v2.txt")
+ .with_vocab_txt("slanet/vocab-sla-v2.txt")
}
}
diff --git a/src/models/slanet/impl.rs b/src/models/slanet/impl.rs
index fc2524e..54da596 100644
--- a/src/models/slanet/impl.rs
+++ b/src/models/slanet/impl.rs
@@ -2,7 +2,7 @@ use aksr::Builder;
use anyhow::Result;
use ndarray::{s, Axis};
-use crate::{elapsed, models::BaseModelVisual, Image, Keypoint, Options, Ts, Xs, Y};
+use crate::{elapsed, models::BaseModelVisual, Image, Keypoint, ModelConfig, Ts, Xs, Y};
#[derive(Builder, Debug)]
pub struct SLANet {
@@ -19,8 +19,8 @@ impl SLANet {
self.ts.summary();
}
- pub fn new(options: Options) -> Result {
- let base = BaseModelVisual::new(options)?;
+ pub fn new(config: ModelConfig) -> Result {
+ let base = BaseModelVisual::new(config)?;
let spec = base.engine().spec().to_owned();
let sos = 0;
let eos = base.processor().vocab().len() - 1;
diff --git a/src/models/smolvlm/config.rs b/src/models/smolvlm/config.rs
index 1aa4c9c..339f115 100644
--- a/src/models/smolvlm/config.rs
+++ b/src/models/smolvlm/config.rs
@@ -1,58 +1,28 @@
/// Model configuration for `SmolVLM`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn smolvlm() -> Self {
Self::default()
- .with_batch_size(1)
- .with_model_name("smolvlm")
- .with_model_num_dry_run(3)
- }
-
- pub fn smolvlm_vision() -> Self {
- Self::smolvlm()
- .with_model_kind(crate::Kind::Vision)
+ .with_name("smolvlm")
+ .with_batch_size_all(1)
.with_image_mean(&[0.5, 0.5, 0.5])
.with_image_std(&[0.5, 0.5, 0.5])
.with_resize_filter("lanczos3")
- .with_normalize(true)
+ .with_tokenizer_file("smolvlm/tokenizer.json")
}
- pub fn smolvlm_text() -> Self {
- Self::smolvlm().with_model_kind(crate::Kind::Language)
+ pub fn smolvlm_256m() -> Self {
+ Self::smolvlm()
+ .with_scale(crate::Scale::Million(256.))
+ .with_visual_file("256m-vision-encoder.onnx")
+ .with_textual_file("256m-embed-tokens.onnx")
+ .with_textual_decoder_file("256m-decoder-model-merged.onnx")
}
- pub fn smolvlm_vision_256m() -> Self {
- Self::smolvlm_vision()
- .with_model_scale(crate::Scale::Million(256.))
- .with_model_file("256m-vision-encoder.onnx")
- }
-
- pub fn smolvlm_text_embed_256m() -> Self {
- Self::smolvlm_text()
- .with_model_scale(crate::Scale::Million(256.))
- .with_model_file("256m-embed-tokens.onnx")
- }
-
- pub fn smolvlm_decoder_256m() -> Self {
- Self::smolvlm_text()
- .with_model_scale(crate::Scale::Million(256.))
- .with_model_file("256m-decoder-model-merged.onnx")
- }
-
- pub fn smolvlm_vision_500m() -> Self {
- Self::smolvlm_vision()
- .with_model_scale(crate::Scale::Million(500.))
- .with_model_file("500m-vision-encoder.onnx")
- }
-
- pub fn smolvlm_text_embed_500m() -> Self {
- Self::smolvlm_text()
- .with_model_scale(crate::Scale::Million(500.))
- .with_model_file("500m-embed-tokens.onnx")
- }
-
- pub fn smolvlm_decoder_500m() -> Self {
- Self::smolvlm_text()
- .with_model_scale(crate::Scale::Million(500.))
- .with_model_file("500m-decoder-model-merged.onnx")
+ pub fn smolvlm_500m() -> Self {
+ Self::smolvlm()
+ .with_scale(crate::Scale::Million(500.))
+ .with_visual_file("500m-vision-encoder.onnx")
+ .with_textual_file("500m-embed-tokens.onnx")
+ .with_textual_decoder_file("500m-decoder-model-merged.onnx")
}
}
diff --git a/src/models/smolvlm/impl.rs b/src/models/smolvlm/impl.rs
index 9cb09e8..48cf3cc 100644
--- a/src/models/smolvlm/impl.rs
+++ b/src/models/smolvlm/impl.rs
@@ -3,15 +3,13 @@ use anyhow::Result;
use image::GenericImageView;
use ndarray::s;
-use crate::{
- models::BaseModelTextual, Engine, Image, LogitsSampler, Options, Processor, Scale, Ts, Xs, X, Y,
-};
+use crate::{Engine, Image, LogitsSampler, ModelConfig, Processor, Scale, Ts, Xs, X, Y};
#[derive(Debug, Builder)]
pub struct SmolVLM {
- vision: VisionEncoder,
- text_embed: BaseModelTextual,
- decoder: BaseModelTextual,
+ vision: Engine,
+ text_embed: Engine,
+ decoder: Engine,
scale: Scale,
image_token: String,
global_img_token: String,
@@ -25,17 +23,20 @@ pub struct SmolVLM {
num_hidden_layers: usize,
head_dim: usize,
num_key_value_heads: usize,
+ num_patch: usize,
+ batch: usize,
+ width: usize,
+ height: usize,
+ processor: Processor,
+ ts: Ts,
}
impl SmolVLM {
- pub fn new(
- options_vision_encoder: Options,
- options_text_embed: Options,
- options_decode: Options,
- ) -> Result {
- let vision = VisionEncoder::new(options_vision_encoder)?;
- let text_embed = BaseModelTextual::new(options_text_embed)?;
- let decoder = BaseModelTextual::new(options_decode)?;
+ pub fn new(config: ModelConfig) -> Result {
+ let vision = Engine::try_from_config(&config.visual)?;
+ let text_embed = Engine::try_from_config(&config.textual)?;
+ let decoder = Engine::try_from_config(&config.textual_decoder_merged)?;
+
let fake_image_token = "".to_string();
let image_token = "".to_string();
let global_img_token = "".to_string();
@@ -45,12 +46,23 @@ impl SmolVLM {
let image_token_id = 49190;
let image_seq_len = 64;
let max_length = 1024;
- let (num_hidden_layers, head_dim, num_key_value_heads) = match decoder.scale() {
+ let (num_hidden_layers, head_dim, num_key_value_heads) = match &config.scale {
Some(Scale::Million(256.)) => (30, 64, 3),
Some(Scale::Million(500.)) => (32, 64, 5),
_ => unimplemented!(),
};
- let scale = (*decoder.scale().unwrap()).clone();
+ let scale = config.scale.clone().unwrap();
+
+ let (batch, num_patch, height, width, ts) = (
+ vision.batch().opt(),
+ vision.inputs_minoptmax()[0][1].opt(),
+ vision.inputs_minoptmax()[0][3].opt(),
+ vision.inputs_minoptmax()[0][4].opt(),
+ vision.ts.clone(),
+ );
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
Ok(Self {
vision,
@@ -69,6 +81,12 @@ impl SmolVLM {
bos_token,
eos_token,
image_seq_len,
+ batch,
+ num_patch,
+ height,
+ width,
+ ts,
+ processor,
})
}
@@ -86,13 +104,13 @@ impl SmolVLM {
let bs = 1; // TODO
// patches and pixel_attention_mask
- let (patches, nw_nh) = self.vision.process_one(image)?;
+ let (patches, nw_nh) = self.process_one(image)?;
let dims = patches.dims();
let pixel_attention_mask = X::ones(&[dims[0], dims[1], dims[3], dims[4]]);
// input ids
let prompt = self.image_prompt_string(nw_nh, text);
- let mut input_ids: Vec = self.text_embed.processor().encode_text_ids(&prompt, true)?;
+ let mut input_ids: Vec = self.processor.encode_text_ids(&prompt, true)?;
// position ids
let mut position_ids = X::from(
@@ -114,12 +132,11 @@ impl SmolVLM {
for ii in 0..self.max_length {
// inputs embeds
let input_ids_x = X::from(input_ids.clone()).insert_axis(0)?;
- let mut inputs_embeds =
- self.text_embed.inference(input_ids_x.clone().into())?[0].clone();
+ let mut inputs_embeds = self.text_embed.run(input_ids_x.clone().into())?[0].clone();
// encode image and merge
if ii == 0 {
- let image_features = self.vision.inference(Xs::from(vec![
+ let image_features = self.vision.run(Xs::from(vec![
patches.clone(),
pixel_attention_mask.clone(),
]))?[0]
@@ -152,7 +169,7 @@ impl SmolVLM {
}
// decode
- let decoder_outputs = self.decoder.inference(xs.into())?;
+ let decoder_outputs = self.decoder.run(xs.into())?;
let logits = &decoder_outputs[0];
past_key_values = (1..decoder_outputs.len())
.step_by(2)
@@ -186,10 +203,7 @@ impl SmolVLM {
}
// decode tokens
- let text = self
- .text_embed
- .processor()
- .decode_tokens(&token_ids, true)?;
+ let text = self.processor.decode_tokens(&token_ids, true)?;
Ok(text)
}
@@ -233,44 +247,6 @@ impl SmolVLM {
}
}
}
-}
-
-#[derive(Debug, Builder)]
-pub struct VisionEncoder {
- engine: Engine,
- num_patch: usize,
- batch: usize,
- width: usize,
- height: usize,
- processor: Processor,
- ts: Ts,
-}
-
-impl VisionEncoder {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
- let (batch, num_patch, height, width, ts) = (
- engine.batch().opt(),
- engine.inputs_minoptmax()[0][1].opt(),
- engine.inputs_minoptmax()[0][3].opt(),
- engine.inputs_minoptmax()[0][4].opt(),
- engine.ts.clone(),
- );
- let processor = options
- .to_processor()?
- .with_image_width(width as _)
- .with_image_height(height as _);
-
- Ok(Self {
- engine,
- num_patch,
- batch,
- width,
- height,
- processor,
- ts,
- })
- }
fn create_patches(image: &Image, patch_size: (u32, u32)) -> (Vec, (u32, u32)) {
let mut patches = vec![];
@@ -307,10 +283,6 @@ impl VisionEncoder {
(patches, (nw, nh))
}
- pub fn inference(&mut self, xs: Xs) -> Result {
- self.engine.run(xs)
- }
-
pub fn process_one(&mut self, x: &Image) -> Result<(X, (u32, u32))> {
let (patches, nw_nh) = Self::create_patches(x, (self.width as _, self.height as _));
let patches = self.processor.process_images(&patches)?.insert_axis(0)?;
diff --git a/src/models/svtr/config.rs b/src/models/svtr/config.rs
index 8f0c06e..583b7f5 100644
--- a/src/models/svtr/config.rs
+++ b/src/models/svtr/config.rs
@@ -1,8 +1,8 @@
/// Model configuration for `SVTR`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn svtr() -> Self {
Self::default()
- .with_model_name("svtr")
+ .with_name("svtr")
.with_model_ixx(0, 0, (1, 1, 8).into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 48.into())
@@ -14,11 +14,11 @@ impl crate::Options {
}
pub fn svtr_ch() -> Self {
- Self::svtr().with_vocab_txt("vocab-v1-ppocr-rec-ch.txt")
+ Self::svtr().with_vocab_txt("svtr/vocab-v1-ppocr-rec-ch.txt")
}
pub fn svtr_en() -> Self {
- Self::svtr().with_vocab_txt("vocab-v1-ppocr-rec-en.txt")
+ Self::svtr().with_vocab_txt("svtr/vocab-v1-ppocr-rec-en.txt")
}
pub fn ppocr_rec_v3_ch() -> Self {
diff --git a/src/models/svtr/impl.rs b/src/models/svtr/impl.rs
index eff6bda..728c7bb 100644
--- a/src/models/svtr/impl.rs
+++ b/src/models/svtr/impl.rs
@@ -3,7 +3,7 @@ use anyhow::Result;
use ndarray::Axis;
use rayon::prelude::*;
-use crate::{elapsed, DynConf, Engine, Image, Options, Processor, Ts, Xs, Y};
+use crate::{elapsed, DynConf, Engine, Image, ModelConfig, Processor, Ts, Xs, Y};
#[derive(Builder, Debug)]
pub struct SVTR {
@@ -18,18 +18,17 @@ pub struct SVTR {
}
impl SVTR {
- pub fn new(options: Options) -> Result {
- let engine = options.to_engine()?;
+ pub fn new(config: ModelConfig) -> Result {
+ let engine = Engine::try_from_config(&config.model)?;
let (batch, height, width, ts) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&960.into()).opt(),
engine.try_width().unwrap_or(&960.into()).opt(),
engine.ts.clone(),
);
- let spec = options.model_spec().to_string();
- let confs = DynConf::new(options.class_confs(), 1);
- let processor = options
- .to_processor()?
+ let spec = config.model.spec.to_string();
+ let confs = DynConf::new(config.class_confs(), 1);
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
if processor.vocab().is_empty() {
diff --git a/src/models/trocr/config.rs b/src/models/trocr/config.rs
index 8343434..291c8e6 100644
--- a/src/models/trocr/config.rs
+++ b/src/models/trocr/config.rs
@@ -1,92 +1,52 @@
use crate::Scale;
/// Model configuration for `TrOCR`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn trocr() -> Self {
- Self::default().with_model_name("trocr").with_batch_size(1)
- }
-
- pub fn trocr_visual() -> Self {
- Self::trocr()
- .with_model_kind(crate::Kind::Vision)
- .with_model_ixx(0, 1, 3.into())
- .with_model_ixx(0, 2, 384.into())
- .with_model_ixx(0, 3, 384.into())
+ Self::default()
+ .with_name("trocr")
+ .with_batch_size_all(1)
+ .with_visual_ixx(0, 1, 3.into())
+ .with_visual_ixx(0, 2, 384.into())
+ .with_visual_ixx(0, 3, 384.into())
.with_image_mean(&[0.5, 0.5, 0.5])
.with_image_std(&[0.5, 0.5, 0.5])
- .with_resize_filter("Bilinear")
- .with_normalize(true)
+ .with_resize_filter("lanczos3")
+ .with_tokenizer_file("trocr/tokenizer.json")
+ .with_config_file("trocr/config.json")
+ .with_special_tokens_map_file("trocr/special_tokens_map.json")
+ .with_tokenizer_config_file("trocr/tokenizer_config.json")
}
- pub fn trocr_textual() -> Self {
- Self::trocr().with_model_kind(crate::Kind::Language)
- }
-
- pub fn trocr_visual_small() -> Self {
- Self::trocr_visual().with_model_scale(Scale::S)
- }
-
- pub fn trocr_textual_small() -> Self {
- Self::trocr_textual()
- .with_model_scale(Scale::S)
+ pub fn trocr_small_printed() -> Self {
+ Self::trocr()
+ .with_scale(Scale::S)
+ .with_visual_file("s-encoder-printed.onnx")
+ .with_textual_decoder_file("s-decoder-printed.onnx")
+ .with_textual_decoder_merged_file("s-decoder-merged-printed.onnx")
.with_tokenizer_file("trocr/tokenizer-small.json")
}
- pub fn trocr_visual_base() -> Self {
- Self::trocr_visual().with_model_scale(Scale::B)
- }
-
- pub fn trocr_textual_base() -> Self {
- Self::trocr_textual()
- .with_model_scale(Scale::B)
+ pub fn trocr_base_handwritten() -> Self {
+ Self::trocr()
+ .with_scale(Scale::B)
+ .with_visual_file("b-encoder-handwritten.onnx")
+ .with_textual_decoder_file("b-decoder-handwritten.onnx")
+ .with_textual_decoder_merged_file("b-decoder-merged-handwritten.onnx")
.with_tokenizer_file("trocr/tokenizer-base.json")
}
- pub fn trocr_encoder_small_printed() -> Self {
- Self::trocr_visual_small().with_model_file("s-encoder-printed.onnx")
+ pub fn trocr_small_handwritten() -> Self {
+ Self::trocr_small_printed()
+ .with_visual_file("s-encoder-handwritten.onnx")
+ .with_textual_decoder_file("s-decoder-handwritten.onnx")
+ .with_textual_decoder_merged_file("s-decoder-merged-handwritten.onnx")
}
- pub fn trocr_decoder_small_printed() -> Self {
- Self::trocr_textual_small().with_model_file("s-decoder-printed.onnx")
- }
-
- pub fn trocr_decoder_merged_small_printed() -> Self {
- Self::trocr_textual_small().with_model_file("s-decoder-merged-printed.onnx")
- }
-
- pub fn trocr_encoder_small_handwritten() -> Self {
- Self::trocr_visual_small().with_model_file("s-encoder-handwritten.onnx")
- }
-
- pub fn trocr_decoder_small_handwritten() -> Self {
- Self::trocr_textual_small().with_model_file("s-decoder-handwritten.onnx")
- }
-
- pub fn trocr_decoder_merged_small_handwritten() -> Self {
- Self::trocr_textual_small().with_model_file("s-decoder-merged-handwritten.onnx")
- }
-
- pub fn trocr_encoder_base_printed() -> Self {
- Self::trocr_visual_base().with_model_file("b-encoder-printed.onnx")
- }
-
- pub fn trocr_decoder_base_printed() -> Self {
- Self::trocr_textual_base().with_model_file("b-decoder-printed.onnx")
- }
-
- pub fn trocr_decoder_merged_base_printed() -> Self {
- Self::trocr_textual_base().with_model_file("b-decoder-merged-printed.onnx")
- }
-
- pub fn trocr_encoder_base_handwritten() -> Self {
- Self::trocr_visual_base().with_model_file("b-encoder-handwritten.onnx")
- }
-
- pub fn trocr_decoder_base_handwritten() -> Self {
- Self::trocr_textual_base().with_model_file("b-decoder-handwritten.onnx")
- }
-
- pub fn trocr_decoder_merged_base_handwritten() -> Self {
- Self::trocr_textual_base().with_model_file("b-decoder-merged-handwritten.onnx")
+ pub fn trocr_base_printed() -> Self {
+ Self::trocr_base_handwritten()
+ .with_visual_file("b-encoder-printed.onnx")
+ .with_textual_decoder_file("b-decoder-printed.onnx")
+ .with_textual_decoder_merged_file("b-decoder-merged-printed.onnx")
}
}
diff --git a/src/models/trocr/impl.rs b/src/models/trocr/impl.rs
index a0759a7..e7a981e 100644
--- a/src/models/trocr/impl.rs
+++ b/src/models/trocr/impl.rs
@@ -3,11 +3,7 @@ use anyhow::Result;
use ndarray::{s, Axis};
use rayon::prelude::*;
-use crate::{
- elapsed,
- models::{BaseModelTextual, BaseModelVisual},
- Image, LogitsSampler, Options, Scale, Ts, Xs, X, Y,
-};
+use crate::{elapsed, Engine, Image, LogitsSampler, ModelConfig, Processor, Scale, Ts, Xs, X, Y};
#[derive(Debug, Copy, Clone)]
pub enum TrOCRKind {
@@ -29,35 +25,31 @@ impl TryFrom<&str> for TrOCRKind {
#[derive(Debug, Builder)]
pub struct TrOCR {
- encoder: BaseModelVisual,
- decoder: BaseModelTextual,
- decoder_merged: BaseModelTextual,
+ encoder: Engine,
+ decoder: Engine,
+ decoder_merged: Engine,
max_length: u32,
eos_token_id: u32,
decoder_start_token_id: u32,
ts: Ts,
n_kvs: usize,
+ processor: Processor,
+ batch: usize,
+ height: usize,
+ width: usize,
}
impl TrOCR {
- pub fn summary(&self) {
- self.ts.summary();
- }
-
- pub fn new(
- options_encoder: Options,
- options_decoder: Options,
- options_decoder_merged: Options,
-    ) -> Result<Self> {
- let encoder = BaseModelVisual::new(options_encoder)?;
- let decoder = BaseModelTextual::new(options_decoder)?;
- let decoder_merged = BaseModelTextual::new(options_decoder_merged)?;
- let ts = Ts::merge(&[
- encoder.engine().ts(),
- decoder.engine().ts(),
- decoder_merged.engine().ts(),
- ]);
-
+    pub fn new(config: ModelConfig) -> Result<Self> {
+ let encoder = Engine::try_from_config(&config.visual)?;
+ let decoder = Engine::try_from_config(&config.textual_decoder)?;
+ let decoder_merged = Engine::try_from_config(&config.textual_decoder_merged)?;
+ let (batch, height, width) = (
+ encoder.batch().opt(),
+ encoder.try_height().unwrap_or(&384.into()).opt(),
+ encoder.try_width().unwrap_or(&384.into()).opt(),
+ );
+ let ts = Ts::merge(&[encoder.ts(), decoder.ts(), decoder_merged.ts()]);
        // "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>",
        // "model_max_length": 1000000000000000019884624838656,
        // let bos_token = "<s>";
@@ -68,12 +60,16 @@ impl TrOCR {
let max_length = 1024; // TODO
let eos_token_id = 2;
let decoder_start_token_id = 2;
- let n_kvs = match decoder.scale() {
+ let n_kvs = match &config.scale {
Some(Scale::S) => 6,
Some(Scale::B) => 12,
_ => unimplemented!(),
};
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
+
Ok(Self {
encoder,
decoder,
@@ -83,11 +79,22 @@ impl TrOCR {
eos_token_id,
decoder_start_token_id,
n_kvs,
+ batch,
+ width,
+ height,
+ processor,
})
}
+    pub fn encode(&mut self, xs: &[Image]) -> Result<X> {
+ let ys = self.processor.process_images(xs)?;
+ self.batch = xs.len(); // update
+ let ys = self.encoder.run(ys.into())?;
+ Ok(ys[0].to_owned())
+ }
+
    pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> {
- let encoder_hidden_states = elapsed!("encode", self.ts, { self.encoder.encode(xs)? });
+ let encoder_hidden_states = elapsed!("encode", self.ts, { self.encode(xs)? });
let generated = elapsed!("generate", self.ts, {
self.generate(&encoder_hidden_states)?
});
@@ -100,10 +107,10 @@ impl TrOCR {
// input_ids
let input_ids = X::from(vec![self.decoder_start_token_id as f32])
.insert_axis(0)?
- .repeat(0, self.encoder.batch())?;
+ .repeat(0, self.batch)?;
// decoder
- let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
+ let mut decoder_outputs = self.decoder.run(Xs::from(vec![
input_ids.clone(),
encoder_hidden_states.clone(),
]))?;
@@ -116,9 +123,9 @@ impl TrOCR {
.collect();
// token ids
-        let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
-        let mut finished = vec![false; self.encoder.batch()];
-        let mut last_tokens: Vec<f32> = vec![0.; self.encoder.batch()];
+        let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.batch];
+        let mut finished = vec![false; self.batch];
+        let mut last_tokens: Vec<f32> = vec![0.; self.batch];
let logits_sampler = LogitsSampler::new();
// generate
@@ -169,78 +176,15 @@ impl TrOCR {
xs.push(X::ones(&[1])); // use_cache
// generate
- decoder_outputs = self.decoder_merged.inference(xs.into())?;
+ decoder_outputs = self.decoder_merged.run(xs.into())?;
}
Ok(token_ids)
}
-    // fn generate(&mut self, encoder_hidden_states: &X) -> Result<Vec<Vec<u32>>> {
- // // input_ids
- // let input_ids = X::from(vec![self.decoder_start_token_id as f32])
- // .insert_axis(0)?
- // .repeat(0, self.encoder.batch())?;
-
- // // decoder
- // let mut decoder_outputs = self.decoder.inference(Xs::from(vec![
- // input_ids.clone(),
- // encoder_hidden_states.clone(),
- // ]))?;
-
- // // encoder kvs
- // let encoder_kvs: Vec<_> = (3..4 * self.n_kvs)
- // .step_by(4)
- // .flat_map(|i| [i, i + 1])
- // .map(|i| decoder_outputs[i].clone())
- // .collect();
-
- // // token ids
-    //     let mut token_ids: Vec<Vec<u32>> = vec![vec![]; self.encoder.batch()];
-
- // // generate
- // for _ in 0..self.max_length {
- // let logits = &decoder_outputs[0];
- // let decoder_kvs: Vec<_> = (1..(4 * self.n_kvs) - 2)
- // .step_by(4)
- // .flat_map(|i| [i, i + 1])
- // .map(|i| decoder_outputs[i].clone())
- // .collect();
-
- // // decode each token for each batch
- // let (finished, last_tokens) = self.decoder_merged.processor().par_generate(
- // logits,
- // &mut token_ids,
- // self.eos_token_id,
- // )?;
-
- // if finished {
- // break;
- // }
-
- // // build inputs
- // let input_ids = X::from(last_tokens).insert_axis(1)?;
- // let mut xs = vec![input_ids, encoder_hidden_states.clone()];
- // for i in 0..self.n_kvs {
- // xs.push(decoder_kvs[i * 2].clone());
- // xs.push(decoder_kvs[i * 2 + 1].clone());
- // xs.push(encoder_kvs[i * 2].clone());
- // xs.push(encoder_kvs[i * 2 + 1].clone());
- // }
- // xs.push(X::ones(&[1])); // use_cache
-
- // // generate
- // decoder_outputs = self.decoder_merged.inference(xs.into())?;
- // }
-
- // Ok(token_ids)
- // }
-
    pub fn decode(&self, token_ids: Vec<Vec<u32>>) -> Result<Vec<Y>> {
// decode
- let texts = self
- .decoder_merged
- .processor()
- .decode_tokens_batch(&token_ids, false)?;
+ let texts = self.processor.decode_tokens_batch(&token_ids, false)?;
// to texts
let texts = texts
@@ -250,101 +194,8 @@ impl TrOCR {
Ok(texts)
}
+
+ pub fn summary(&self) {
+ self.ts.summary();
+ }
}
-
-// #[derive(Debug, Builder)]
-// pub struct TrOCREncoder {
-// // TODO: `BaseVisualEncoder`, `BaseVisualModel` struct?
-// engine: Engine,
-// height: usize,
-// width: usize,
-// batch: usize,
-// processor: Processor,
-// }
-
-// impl TrOCREncoder {
-    //     pub fn new(options: Options) -> Result<Self> {
-// let engine = options.to_engine()?;
-// let (batch, height, width) = (
-// engine.batch().opt(),
-// engine.try_height().unwrap_or(&384.into()).opt(),
-// engine.try_width().unwrap_or(&384.into()).opt(),
-// );
-// let processor = options
-// .to_processor()?
-// .with_image_width(width as _)
-// .with_image_height(height as _);
-
-// Ok(Self {
-// engine,
-// height,
-// width,
-// batch,
-// processor,
-// })
-// }
-
-    //     pub fn preprocess(&mut self, xs: &[DynamicImage]) -> Result<Xs> {
-// self.batch = xs.len(); // TODO
-// let x = self.processor.process_images(xs)?;
-
-// Ok(x.into())
-// }
-
-    //     pub fn inference(&mut self, xs: Xs) -> Result<Xs> {
-// self.engine.run(xs)
-// }
-
-    //     fn encode(&mut self, xs: &[DynamicImage]) -> Result<X> {
-// // encode a batch of images into one embedding, that's `X`
-// let xs = self.preprocess(xs)?;
-// let xs = self.inference(xs)?;
-// let x = xs[0].to_owned();
-
-// Ok(x)
-// }
-// }
-
-// #[derive(Debug, Builder)]
-// pub struct TrOCRDecoder {
-// engine: Engine,
-// batch: usize,
-// }
-
-// impl TrOCRDecoder {
-    //     pub fn new(options: Options) -> Result<Self> {
-// let engine = options.to_engine()?;
-// let batch = engine.batch().opt();
-
-// Ok(Self { engine, batch })
-// }
-
-    //     pub fn inference(&mut self, xs: Xs) -> Result<Xs> {
-// self.engine.run(xs)
-// }
-// }
-
-// #[derive(Debug, Builder)]
-// pub struct TrOCRDecoderMerged {
-// engine: Engine,
-// batch: usize,
-// processor: Processor,
-// }
-
-// impl TrOCRDecoderMerged {
-    //     pub fn new(options: Options) -> Result<Self> {
-// let engine = options.to_engine()?;
-// let batch = engine.batch().opt();
-// let processor = options.to_processor()?;
-
-// Ok(Self {
-// engine,
-// batch,
-// processor,
-// })
-// }
-
-    //     pub fn inference(&mut self, xs: Xs) -> Result<Xs> {
-// self.engine.run(xs)
-// }
-// }
diff --git a/src/models/yolo/config.rs b/src/models/yolo/config.rs
index 8034827..7941d03 100644
--- a/src/models/yolo/config.rs
+++ b/src/models/yolo/config.rs
@@ -1,234 +1,134 @@
use crate::{
- models::YOLOPredsFormat, Options, ResizeMode, Scale, Task, NAMES_COCO_KEYPOINTS_17,
- NAMES_YOLO_DOCLAYOUT_10,
+ models::YOLOPredsFormat, ModelConfig, ResizeMode, Scale, Task, NAMES_COCO_80,
+ NAMES_COCO_KEYPOINTS_17, NAMES_IMAGENET_1K, NAMES_YOLO_DOCLAYOUT_10,
};
-impl Options {
+impl ModelConfig {
pub fn yolo() -> Self {
Self::default()
- .with_model_name("yolo")
+ .with_name("yolo")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 1, 3.into())
.with_model_ixx(0, 2, 640.into())
.with_model_ixx(0, 3, 640.into())
.with_resize_mode(ResizeMode::FitAdaptive)
.with_resize_filter("CatmullRom")
- .with_find_contours(true)
- }
-
- pub fn doclayout_yolo_docstructbench() -> Self {
- Self::yolo_v10()
- .with_model_file("doclayout-docstructbench.onnx") // TODO: batch_size > 1
- .with_model_ixx(0, 2, (640, 1024, 1024).into())
- .with_model_ixx(0, 3, (640, 1024, 1024).into())
- .with_class_confs(&[0.4])
- .with_class_names(&NAMES_YOLO_DOCLAYOUT_10)
+ .with_class_names(&NAMES_COCO_80)
}
pub fn yolo_classify() -> Self {
Self::yolo()
- .with_model_task(Task::ImageClassification)
+ .with_task(Task::ImageClassification)
.with_model_ixx(0, 2, 224.into())
.with_model_ixx(0, 3, 224.into())
.with_resize_mode(ResizeMode::FitExact)
.with_resize_filter("Bilinear")
+ .with_class_names(&NAMES_IMAGENET_1K)
}
pub fn yolo_detect() -> Self {
- Self::yolo().with_model_task(Task::ObjectDetection)
+ Self::yolo().with_task(Task::ObjectDetection)
}
pub fn yolo_pose() -> Self {
Self::yolo()
- .with_model_task(Task::KeypointsDetection)
+ .with_task(Task::KeypointsDetection)
.with_keypoint_names(&NAMES_COCO_KEYPOINTS_17)
}
pub fn yolo_segment() -> Self {
- Self::yolo().with_model_task(Task::InstanceSegmentation)
+ Self::yolo().with_task(Task::InstanceSegmentation)
}
pub fn yolo_obb() -> Self {
- Self::yolo().with_model_task(Task::OrientedObjectDetection)
+ Self::yolo().with_task(Task::OrientedObjectDetection)
}
- pub fn fastsam_s() -> Self {
- Self::yolo_segment()
- .with_model_scale(Scale::S)
- .with_model_version(8.into())
- .with_model_file("FastSAM-s.onnx")
+ pub fn auto_yolo_model_file(mut self) -> Self {
+ if self.model.file.is_empty() {
+ // [version]-[scale]-[task]
+ let mut y = String::new();
+ if let Some(x) = self.version() {
+ y.push_str(&x.to_string());
+ }
+ if let Some(x) = self.scale() {
+ y.push_str(&format!("-{}", x));
+ }
+ if let Some(x) = self.task() {
+ y.push_str(&format!("-{}", x.yolo_str()));
+ }
+ y.push_str(".onnx");
+ self.model.file = y;
+ }
+
+ self
}
- pub fn yolo_v8_rtdetr() -> Self {
- Self::yolo()
- .with_model_version(7.into())
- .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n())
- }
-
- pub fn yolo_v8_rtdetr_l() -> Self {
- Self::yolo_v8_rtdetr()
- .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n())
- .with_model_scale(Scale::L)
- .with_model_file("rtdetr-l-det.onnx")
- }
-
- pub fn yolo_v8_rtdetr_x() -> Self {
- Self::yolo_v8_rtdetr()
- .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n())
- .with_model_scale(Scale::X)
- }
-
- pub fn yolo_n() -> Self {
- Self::yolo().with_model_scale(Scale::N)
- }
-
- pub fn yolo_s() -> Self {
- Self::yolo().with_model_scale(Scale::S)
- }
-
- pub fn yolo_m() -> Self {
- Self::yolo().with_model_scale(Scale::M)
- }
-
- pub fn yolo_l() -> Self {
- Self::yolo().with_model_scale(Scale::L)
- }
-
- pub fn yolo_x() -> Self {
- Self::yolo().with_model_scale(Scale::X)
- }
-
- pub fn yolo_v5() -> Self {
- Self::yolo().with_model_version(5.into())
- }
-
- pub fn yolo_v6() -> Self {
- Self::yolo().with_model_version(6.into())
- }
-
- pub fn yolo_v7() -> Self {
- Self::yolo().with_model_version(7.into())
- }
-
- pub fn yolo_v8() -> Self {
- Self::yolo().with_model_version(8.into())
- }
-
- pub fn yolo_v9() -> Self {
- Self::yolo().with_model_version(9.into())
- }
-
- pub fn yolo_v10() -> Self {
- Self::yolo().with_model_version(10.into())
- }
-
- pub fn yolo_v11() -> Self {
- Self::yolo().with_model_version(11.into())
- }
-
- pub fn yolo_v12() -> Self {
- Self::yolo().with_model_version(12.into())
- }
-
- pub fn yolo_v8_n() -> Self {
- Self::yolo()
- .with_model_version(8.into())
- .with_model_scale(Scale::N)
- }
-
- pub fn yolo_v8_s() -> Self {
- Self::yolo()
- .with_model_version(8.into())
- .with_model_scale(Scale::S)
- }
-
- pub fn yolo_v8_m() -> Self {
- Self::yolo()
- .with_model_version(8.into())
- .with_model_scale(Scale::M)
- }
-
- pub fn yolo_v8_l() -> Self {
- Self::yolo()
- .with_model_version(8.into())
- .with_model_scale(Scale::L)
- }
-
- pub fn yolo_v8_x() -> Self {
- Self::yolo()
- .with_model_version(8.into())
- .with_model_scale(Scale::X)
- }
-
- pub fn yolo_v11_n() -> Self {
- Self::yolo()
- .with_model_version(11.into())
- .with_model_scale(Scale::N)
- }
-
- pub fn yolo_v11_s() -> Self {
- Self::yolo()
- .with_model_version(11.into())
- .with_model_scale(Scale::S)
- }
-
- pub fn yolo_v11_m() -> Self {
- Self::yolo()
- .with_model_version(11.into())
- .with_model_scale(Scale::M)
- }
-
- pub fn yolo_v11_l() -> Self {
- Self::yolo()
- .with_model_version(11.into())
- .with_model_scale(Scale::L)
- }
-
- pub fn yolo_v11_x() -> Self {
- Self::yolo()
- .with_model_version(11.into())
- .with_model_scale(Scale::X)
+ pub fn doclayout_yolo_docstructbench() -> Self {
+ Self::yolo_detect()
+ .with_version(10.into())
+ .with_model_ixx(0, 2, (640, 1024, 1024).into())
+ .with_model_ixx(0, 3, (640, 1024, 1024).into())
+ .with_class_confs(&[0.4])
+ .with_class_names(&NAMES_YOLO_DOCLAYOUT_10)
+ .with_model_file("doclayout-docstructbench.onnx") // TODO: batch_size > 1
}
+ // YOLOE models
pub fn yoloe_v8s_seg_pf() -> Self {
- Self::yolo()
- .with_model_version(8.into())
- .with_model_scale(Scale::S)
+ Self::yolo_segment()
+ .with_version(8.into())
+ .with_scale(Scale::S)
.with_model_file("yoloe-v8s-seg-pf.onnx")
}
pub fn yoloe_v8m_seg_pf() -> Self {
- Self::yolo()
- .with_model_version(8.into())
- .with_model_scale(Scale::M)
+ Self::yolo_segment()
+ .with_version(8.into())
+ .with_scale(Scale::M)
.with_model_file("yoloe-v8m-seg-pf.onnx")
}
pub fn yoloe_v8l_seg_pf() -> Self {
- Self::yolo()
- .with_model_version(8.into())
- .with_model_scale(Scale::L)
+ Self::yolo_segment()
+ .with_version(8.into())
+ .with_scale(Scale::L)
.with_model_file("yoloe-v8l-seg-pf.onnx")
}
pub fn yoloe_11s_seg_pf() -> Self {
- Self::yolo()
- .with_model_version(11.into())
- .with_model_scale(Scale::S)
+ Self::yolo_segment()
+ .with_version(11.into())
+ .with_scale(Scale::S)
.with_model_file("yoloe-11s-seg-pf.onnx")
}
pub fn yoloe_11m_seg_pf() -> Self {
- Self::yolo()
- .with_model_version(11.into())
- .with_model_scale(Scale::M)
+ Self::yolo_segment()
+ .with_version(11.into())
+ .with_scale(Scale::M)
.with_model_file("yoloe-v8m-seg-pf.onnx")
}
pub fn yoloe_11l_seg_pf() -> Self {
- Self::yolo()
- .with_model_version(11.into())
- .with_model_scale(Scale::L)
+ Self::yolo_segment()
+ .with_version(11.into())
+ .with_scale(Scale::L)
.with_model_file("yoloe-11l-seg-pf.onnx")
}
+
+ /// ---- TODO
+ pub fn fastsam_s() -> Self {
+ Self::yolo_segment()
+ .with_scale(Scale::S)
+ .with_version(8.into())
+ .with_model_file("FastSAM-s.onnx")
+ }
+
+ pub fn yolo_v8_rtdetr_l() -> Self {
+ Self::yolo_detect()
+ .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n())
+ .with_scale(Scale::L)
+ .with_model_file("rtdetr-l-det.onnx")
+ }
}
diff --git a/src/models/yolo/impl.rs b/src/models/yolo/impl.rs
index a15d902..7bf550d 100644
--- a/src/models/yolo/impl.rs
+++ b/src/models/yolo/impl.rs
@@ -8,8 +8,8 @@ use regex::Regex;
use crate::{
elapsed,
models::{BoxType, YOLOPredsFormat},
- DynConf, Engine, Hbb, Image, Keypoint, Mask, NmsOps, Obb, Ops, Options, Prob, Processor, Task,
- Ts, Version, Xs, Y,
+ DynConf, Engine, Hbb, Image, Keypoint, Mask, ModelConfig, NmsOps, Obb, Ops, Prob, Processor,
+ Task, Ts, Version, Xs, Y,
};
#[derive(Debug, Builder)]
@@ -36,17 +36,18 @@ pub struct YOLO {
topk: usize,
}
-impl TryFrom<Options> for YOLO {
+impl TryFrom<ModelConfig> for YOLO {
     type Error = anyhow::Error;
-    fn try_from(options: Options) -> Result<Self, Self::Error> {
-        Self::new(options)
+    fn try_from(config: ModelConfig) -> Result<Self, Self::Error> {
+        Self::new(config)
}
}
impl YOLO {
-    pub fn new(options: Options) -> Result<Self> {
-        let engine = options.to_engine()?;
+    pub fn new(config: ModelConfig) -> Result<Self> {
+        let engine = Engine::try_from_config(&config.model)?;
+
let (batch, height, width, ts, spec) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&640.into()).opt(),
@@ -54,11 +55,8 @@ impl YOLO {
engine.ts.clone(),
engine.spec().to_owned(),
);
- let processor = options
- .to_processor()?
- .with_image_width(width as _)
- .with_image_height(height as _);
-        let task: Option<Task> = match &options.model_task {
+
+        let task: Option<Task> = match &config.task {
Some(task) => Some(task.clone()),
None => match engine.try_fetch("task") {
Some(x) => match x.as_str() {
@@ -77,8 +75,8 @@ impl YOLO {
};
// Task & layout
- let version = options.model_version;
- let (layout, task) = match &options.yolo_preds_format {
+ let version = config.version;
+ let (layout, task) = match &config.yolo_preds_format {
// customized
Some(layout) => {
// check task
@@ -172,7 +170,7 @@ impl YOLO {
// Class names
        let names: Option<Vec<String>> = match Self::fetch_names_from_onnx(&engine) {
- Some(names_parsed) => match &options.class_names {
+ Some(names_parsed) => match &config.class_names {
Some(names) => {
if names.len() == names_parsed.len() {
// prioritize user-defined
@@ -188,25 +186,25 @@ impl YOLO {
}
None => Some(names_parsed),
},
- None => options.class_names.clone(),
+ None => config.class_names.clone(),
};
// Class names & Number of class
- let (nc, names) = match (options.nc(), names) {
+ let (nc, names) = match (config.nc(), names) {
(_, Some(names)) => (names.len(), names.to_vec()),
(Some(nc), None) => (nc, Self::n2s(nc)),
(None, None) => {
anyhow::bail!(
"Neither class names nor the number of classes were specified. \
- \nConsider specify them with `Options::default().with_nc()` or `Options::default().with_class_names()`"
+ \nConsider specify them with `ModelConfig::default().with_nc()` or `ModelConfig::default().with_class_names()`"
);
}
};
// Keypoint names & Number of keypoints
let (nk, names_kpt) = if let Task::KeypointsDetection = task {
- let nk = Self::fetch_nk_from_onnx(&engine).or(options.nk());
- match (&options.keypoint_names, nk) {
+ let nk = Self::fetch_nk_from_onnx(&engine).or(config.nk());
+ match (&config.keypoint_names, nk) {
(Some(names), Some(nk)) => {
if names.len() != nk {
anyhow::bail!(
@@ -221,7 +219,7 @@ impl YOLO {
(None, Some(nk)) => (nk, Self::n2s(nk)),
(None, None) => anyhow::bail!(
"Neither keypoint names nor the number of keypoints were specified when doing `KeypointsDetection` task. \
- \nConsider specify them with `Options::default().with_nk()` or `Options::default().with_keypoint_names()`"
+ \nConsider specify them with `ModelConfig::default().with_nk()` or `ModelConfig::default().with_keypoint_names()`"
),
}
} else {
@@ -229,12 +227,12 @@ impl YOLO {
};
// Attributes
- let topk = options.topk().unwrap_or(5);
- let confs = DynConf::new(options.class_confs(), nc);
- let kconfs = DynConf::new(options.keypoint_confs(), nk);
- let iou = options.iou().unwrap_or(0.45);
- let classes_excluded = options.classes_excluded().to_vec();
- let classes_retained = options.classes_retained().to_vec();
+ let topk = config.topk().unwrap_or(5);
+ let confs = DynConf::new(config.class_confs(), nc);
+ let kconfs = DynConf::new(config.keypoint_confs(), nk);
+ let iou = config.iou().unwrap_or(0.45);
+ let classes_excluded = config.classes_excluded().to_vec();
+ let classes_retained = config.classes_retained().to_vec();
let mut info = format!(
"YOLO Version: {}, Task: {:?}, Category Count: {}, Keypoint Count: {}, TopK: {}",
version.map_or("Unknown".into(), |x| x.to_string()),
@@ -249,6 +247,10 @@ impl YOLO {
if !classes_retained.is_empty() {
info = format!("{}, classes_retained: {:?}", info, classes_retained);
}
+ let processor = Processor::try_from_config(&config.processor)?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
+
info!("{}", info);
Ok(Self {
diff --git a/src/models/yolop/config.rs b/src/models/yolop/config.rs
index 6e1564e..9736c7d 100644
--- a/src/models/yolop/config.rs
+++ b/src/models/yolop/config.rs
@@ -1,14 +1,12 @@
/// Model configuration for `YOLOP`
-impl crate::Options {
+impl crate::ModelConfig {
pub fn yolop() -> Self {
Self::default()
- .with_model_name("yolop")
+ .with_name("yolop")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 2, 640.into())
.with_model_ixx(0, 3, 640.into())
.with_resize_mode(crate::ResizeMode::FitAdaptive)
- .with_resize_filter("Bilinear")
- .with_normalize(true)
.with_class_confs(&[0.3])
}
diff --git a/src/models/yolop/impl.rs b/src/models/yolop/impl.rs
index 1cbde83..b9a5c3f 100644
--- a/src/models/yolop/impl.rs
+++ b/src/models/yolop/impl.rs
@@ -3,7 +3,7 @@ use anyhow::Result;
use ndarray::{s, Array, Axis, IxDyn};
use crate::{
- elapsed, DynConf, Engine, Hbb, Image, NmsOps, Ops, Options, Polygon, Processor, Ts, Xs, Y,
+ elapsed, DynConf, Engine, Hbb, Image, ModelConfig, NmsOps, Ops, Polygon, Processor, Ts, Xs, Y,
};
#[derive(Builder, Debug)]
@@ -20,8 +20,8 @@ pub struct YOLOPv2 {
}
impl YOLOPv2 {
-    pub fn new(options: Options) -> Result<Self> {
-        let engine = options.to_engine()?;
+    pub fn new(config: ModelConfig) -> Result<Self> {
+ let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
@@ -29,14 +29,13 @@ impl YOLOPv2 {
engine.try_width().unwrap_or(&512.into()).opt(),
engine.ts().clone(),
);
- let processor = options
- .to_processor()?
+
+ let confs = DynConf::new(config.class_confs(), 80);
+ let iou = config.iou.unwrap_or(0.45f32);
+ let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _)
.with_image_height(height as _);
- let confs = DynConf::new(options.class_confs(), 80);
- let iou = options.iou.unwrap_or(0.45f32);
-
Ok(Self {
engine,
height,
diff --git a/src/utils/device.rs b/src/utils/device.rs
index 8c612cd..a8099cf 100644
--- a/src/utils/device.rs
+++ b/src/utils/device.rs
@@ -1,16 +1,20 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Device {
- Auto(usize),
Cpu(usize),
Cuda(usize),
TensorRT(usize),
CoreML(usize),
}
+impl Default for Device {
+ fn default() -> Self {
+ Self::Cpu(0)
+ }
+}
+
impl std::fmt::Display for Device {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let x = match self {
- Self::Auto(i) => format!("auto:{}", i),
Self::Cpu(i) => format!("cpu:{}", i),
Self::Cuda(i) => format!("cuda:{}", i),
Self::CoreML(i) => format!("mps:{}", i),
@@ -47,7 +51,6 @@ impl TryFrom<&str> for Device {
impl Device {
pub fn id(&self) -> usize {
match self {
- Device::Auto(i) => *i,
Device::Cpu(i) => *i,
Device::Cuda(i) => *i,
Device::TensorRT(i) => *i,
diff --git a/src/utils/dtype.rs b/src/utils/dtype.rs
index 538ae83..5874007 100644
--- a/src/utils/dtype.rs
+++ b/src/utils/dtype.rs
@@ -1,5 +1,6 @@
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DType {
+ #[default]
Auto,
Int4,
Int8,
diff --git a/src/utils/kind.rs b/src/utils/kind.rs
deleted file mode 100644
index 4519427..0000000
--- a/src/utils/kind.rs
+++ /dev/null
@@ -1,18 +0,0 @@
-#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
-pub enum Kind {
- // Do we really need this?
- Vision,
- Language,
- VisionLanguage,
-}
-
-impl std::fmt::Display for Kind {
- fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
- let x = match self {
- Self::Vision => "visual",
- Self::Language => "textual",
- Self::VisionLanguage => "vl",
- };
- write!(f, "{}", x)
- }
-}
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
index 50b8068..008dfb6 100644
--- a/src/utils/mod.rs
+++ b/src/utils/mod.rs
@@ -2,13 +2,12 @@ mod device;
mod dtype;
mod dynconf;
mod iiix;
-mod kind;
mod logits_sampler;
mod min_opt_max;
mod names;
mod ops;
-mod options;
mod processor;
+mod processor_config;
mod retry;
mod scale;
mod task;
@@ -20,13 +19,12 @@ pub use device::Device;
pub use dtype::DType;
pub use dynconf::DynConf;
pub(crate) use iiix::Iiix;
-pub use kind::Kind;
pub use logits_sampler::LogitsSampler;
pub use min_opt_max::MinOptMax;
pub use names::*;
pub use ops::*;
-pub use options::*;
pub use processor::*;
+pub use processor_config::ProcessorConfig;
pub use scale::Scale;
pub use task::Task;
pub use traits::*;
diff --git a/src/utils/ops.rs b/src/utils/ops.rs
index 8d416ce..9f54162 100644
--- a/src/utils/ops.rs
+++ b/src/utils/ops.rs
@@ -267,12 +267,12 @@ impl Ops<'_> {
PixelType::F32,
)?;
let mut dst = Image::new(w1 as _, h1 as _, src.pixel_type());
- let (mut resizer, mut options) = Self::build_resizer_filter(filter)?;
+ let (mut resizer, mut config) = Self::build_resizer_filter(filter)?;
if crop_src {
let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _);
- options = options.crop(0., 0., w.into(), h.into());
+ config = config.crop(0., 0., w.into(), h.into());
};
- resizer.resize(&src, &mut dst, &options)?;
+ resizer.resize(&src, &mut dst, &config)?;
// u8 -> f32
Self::u8_slice_to_f32(&dst.into_vec())
@@ -317,12 +317,12 @@ impl Ops<'_> {
) -> Result> {
let src = Image::from_vec_u8(w0 as _, h0 as _, v.to_vec(), PixelType::U8)?;
let mut dst = Image::new(w1 as _, h1 as _, src.pixel_type());
- let (mut resizer, mut options) = Self::build_resizer_filter(filter)?;
+ let (mut resizer, mut config) = Self::build_resizer_filter(filter)?;
if crop_src {
let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _);
- options = options.crop(0., 0., w.into(), h.into());
+ config = config.crop(0., 0., w.into(), h.into());
};
- resizer.resize(&src, &mut dst, &options)?;
+ resizer.resize(&src, &mut dst, &config)?;
Ok(dst.into_vec())
}
@@ -348,13 +348,13 @@ impl Ops<'_> {
th: u32,
tw: u32,
resizer: &mut Resizer,
- options: &ResizeOptions,
+ config: &ResizeOptions,
) -> Result> {
let buffer = if x.dimensions() == (tw, th) {
x.to_rgb8().into_raw()
} else {
let mut dst = Image::new(tw, th, PixelType::U8x3);
- resizer.resize(x, &mut dst, options)?;
+ resizer.resize(x, &mut dst, config)?;
dst.into_vec()
};
let y = Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?
@@ -370,9 +370,9 @@ impl Ops<'_> {
filter: &str,
) -> Result> {
let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn();
- let (mut resizer, options) = Self::build_resizer_filter(filter)?;
+ let (mut resizer, config) = Self::build_resizer_filter(filter)?;
for (idx, x) in xs.iter().enumerate() {
- let y = Self::resize_rgb(x, th, tw, &mut resizer, &options)?;
+ let y = Self::resize_rgb(x, th, tw, &mut resizer, &config)?;
ys.slice_mut(s![idx, .., .., ..]).assign(&y);
}
Ok(ys)
@@ -387,7 +387,7 @@ impl Ops<'_> {
resize_by: &str,
center: bool,
resizer: &mut Resizer,
- options: &ResizeOptions,
+ config: &ResizeOptions,
) -> Result> {
let (w0, h0) = x.dimensions();
let buffer = if w0 == tw && h0 == th {
@@ -403,7 +403,7 @@ impl Ops<'_> {
}
"height" => (th * w0 / h0, th),
"width" => (tw, tw * h0 / w0),
- _ => anyhow::bail!("ModelConfig for `letterbox`: width, height, auto"),
+ _ => anyhow::bail!("EngineConfig for `letterbox`: width, height, auto"),
};
let mut dst = Image::from_vec_u8(
@@ -422,7 +422,7 @@ impl Ops<'_> {
(0, 0)
};
let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?;
- resizer.resize(x, &mut dst_cropped, options)?;
+ resizer.resize(x, &mut dst_cropped, config)?;
dst.into_vec()
};
let y = Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?
@@ -441,9 +441,9 @@ impl Ops<'_> {
center: bool,
) -> Result> {
let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn();
- let (mut resizer, options) = Self::build_resizer_filter(filter)?;
+ let (mut resizer, config) = Self::build_resizer_filter(filter)?;
for (idx, x) in xs.iter().enumerate() {
- let y = Self::letterbox_rgb(x, th, tw, bg, resize_by, center, &mut resizer, &options)?;
+ let y = Self::letterbox_rgb(x, th, tw, bg, resize_by, center, &mut resizer, &config)?;
ys.slice_mut(s![idx, .., .., ..]).assign(&y);
}
Ok(ys)
diff --git a/src/utils/options.rs b/src/utils/options.rs
deleted file mode 100644
index 613a620..0000000
--- a/src/utils/options.rs
+++ /dev/null
@@ -1,488 +0,0 @@
-//! Options for everthing
-use aksr::Builder;
-use anyhow::Result;
-use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
-
-use crate::{
- models::{SamKind, YOLOPredsFormat},
- try_fetch_file_stem, DType, Device, Engine, Hub, Iiix, Kind, LogitsSampler, MinOptMax,
- Processor, ResizeMode, Scale, Task, Version,
-};
-
-/// Options for building models and inference
-#[derive(Builder, Debug, Clone)]
-pub struct Options {
- // Model configs
- pub model_file: String,
- pub model_name: &'static str,
- pub model_device: Device,
- pub model_dtype: DType,
- pub model_version: Option<Version>,
- pub model_task: Option<Task>,
- pub model_scale: Option<Scale>,
- pub model_kind: Option<Kind>,
- pub model_iiixs: Vec<Iiix>,
- pub model_spec: String,
- pub model_num_dry_run: usize,
- pub trt_fp16: bool,
- pub profile: bool,
-
- // models
- pub model_encoder_file: Option<String>,
- pub model_decoder_file: Option<String>,
- pub visual_encoder_file: Option<String>,
- pub visual_decoder_file: Option<String>,
- pub textual_encoder_file: Option<String>,
- pub textual_decoder_file: Option<String>,
-
- // Processor configs
- #[args(except(setter))]
- pub image_width: u32,
- #[args(except(setter))]
- pub image_height: u32,
- pub resize_mode: ResizeMode,
- pub resize_filter: &'static str,
- pub padding_value: u8,
- pub letterbox_center: bool,
- pub normalize: bool,
- pub image_std: Vec<f32>,
- pub image_mean: Vec<f32>,
- pub nchw: bool,
- pub unsigned: bool,
-
- // Names
- pub class_names: Option<Vec<String>>,
- pub class_names_2: Option<Vec<String>>,
- pub class_names_3: Option<Vec<String>>,
- pub keypoint_names: Option<Vec<String>>,
- pub keypoint_names_2: Option<Vec<String>>,
- pub keypoint_names_3: Option<Vec<String>>,
- pub text_names: Option<Vec<String>>,
- pub text_names_2: Option<Vec<String>>,
- pub text_names_3: Option<Vec<String>>,
- pub category_names: Option<Vec<String>>,
- pub category_names_2: Option<Vec<String>>,
- pub category_names_3: Option<Vec<String>>,
-
- // Confs
- pub class_confs: Vec<f32>,
- pub class_confs_2: Vec<f32>,
- pub class_confs_3: Vec<f32>,
- pub keypoint_confs: Vec<f32>,
- pub keypoint_confs_2: Vec<f32>,
- pub keypoint_confs_3: Vec<f32>,
- pub text_confs: Vec<f32>,
- pub text_confs_2: Vec<f32>,
- pub text_confs_3: Vec<f32>,
-
- // Files
- pub file: Option<String>,
- pub file_2: Option<String>,
- pub file_3: Option<String>,
-
- // For classification
- pub apply_softmax: Option<bool>,
- pub topk: Option<usize>,
- pub topk_2: Option<usize>,
- pub topk_3: Option<usize>,
-
- // For detection
- #[args(aka = "nc")]
- pub num_classes: Option<usize>,
- #[args(aka = "nk")]
- pub num_keypoints: Option<usize>,
- #[args(aka = "nm")]
- pub num_masks: Option<usize>,
- pub iou: Option