From 94fa832359688317e96c9762901f08aa4735ed55 Mon Sep 17 00:00:00 2001
From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com>
Date: Wed, 22 Jan 2025 23:49:51 +0800
Subject: [PATCH] Add Owlv2 model (#59)
---
README.md | 3 +-
examples/owlv2/README.md | 5 ++
examples/owlv2/main.rs | 70 +++++++++++++++++
src/models/mod.rs | 2 +
src/models/owl/README.md | 9 +++
src/models/owl/config.rs | 37 +++++++++
src/models/owl/impl.rs | 156 ++++++++++++++++++++++++++++++++++++++
src/models/owl/mod.rs | 4 +
src/models/rtdetr/impl.rs | 3 +-
src/xy/x.rs | 4 +-
10 files changed, 288 insertions(+), 5 deletions(-)
create mode 100644 examples/owlv2/README.md
create mode 100644 examples/owlv2/main.rs
create mode 100644 src/models/owl/README.md
create mode 100644 src/models/owl/config.rs
create mode 100644 src/models/owl/impl.rs
create mode 100644 src/models/owl/mod.rs
diff --git a/README.md b/README.md
index 85a2517..d78eb5c 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,8 @@
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) | ✅ | ✅ | ✅ | | |
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | ✅ | | |
-| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
+| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
+| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | |
diff --git a/examples/owlv2/README.md b/examples/owlv2/README.md
new file mode 100644
index 0000000..a6ee30e
--- /dev/null
+++ b/examples/owlv2/README.md
@@ -0,0 +1,5 @@
+## Quick Start
+
+```shell
+cargo run -r -F cuda --example owlv2 -- --device cuda:0 --dtype fp16
+```
diff --git a/examples/owlv2/main.rs b/examples/owlv2/main.rs
new file mode 100644
index 0000000..577ae8a
--- /dev/null
+++ b/examples/owlv2/main.rs
@@ -0,0 +1,70 @@
+use anyhow::Result;
+use usls::{models::OWLv2, Annotator, DataLoader, Options};
+
+#[derive(argh::FromArgs)]
+/// Example
+struct Args {
+ /// dtype
+ #[argh(option, default = "String::from(\"auto\")")]
+ dtype: String,
+
+ /// device
+ #[argh(option, default = "String::from(\"cpu:0\")")]
+ device: String,
+
+ /// source image
+ #[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")]
+ source: Vec<String>,
+
+ /// open class names
+ #[argh(
+ option,
+ default = "vec![
+ String::from(\"person\"),
+ String::from(\"hand\"),
+ String::from(\"shoes\"),
+ String::from(\"bus\"),
+ String::from(\"car\"),
+ String::from(\"dog\"),
+ String::from(\"cat\"),
+ String::from(\"sign\"),
+ String::from(\"tie\"),
+ String::from(\"monitor\"),
+ String::from(\"glasses\"),
+ String::from(\"tree\"),
+ String::from(\"head\"),
+ ]"
+ )]
+ labels: Vec<String>,
+}
+
+fn main() -> Result<()> {
+ tracing_subscriber::fmt()
+ .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+ .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
+ .init();
+ let args: Args = argh::from_env();
+
+ // options
+ let options = Options::owlv2_base_ensemble()
+ // owlv2_base()
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .with_class_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
+ .commit()?;
+ let mut model = OWLv2::new(options)?;
+
+ // load
+ let xs = DataLoader::try_read_batch(&args.source)?;
+
+ // run
+ let ys = model.forward(&xs)?;
+
+ // annotate
+ let annotator = Annotator::default()
+ .with_bboxes_thickness(3)
+ .with_saveout(model.spec());
+ annotator.annotate(&xs, &ys);
+
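+ // optionally, print per-stage timings
+ // model.summary();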
+ Ok(())
+}
diff --git a/src/models/mod.rs b/src/models/mod.rs
index 9b5dc02..b80d974 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -17,6 +17,7 @@ mod linknet;
mod mobileone;
mod modnet;
mod moondream2;
+mod owl;
mod picodet;
mod pipeline;
mod rtdetr;
@@ -39,6 +40,7 @@ pub use florence2::*;
pub use grounding_dino::*;
pub use modnet::*;
pub use moondream2::*;
+pub use owl::*;
pub use picodet::*;
pub use pipeline::*;
pub use rtdetr::*;
diff --git a/src/models/owl/README.md b/src/models/owl/README.md
new file mode 100644
index 0000000..e750de0
--- /dev/null
+++ b/src/models/owl/README.md
@@ -0,0 +1,9 @@
+# OWLv2: Scaling Open-Vocabulary Object Detection
+
+## Official Repository
+
+The official repository can be found on [Hugging Face](https://huggingface.co/google/owlv2-base-patch16-ensemble).
+
+## Example
+
+Refer to the [example](../../../examples/owlv2)
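+
+## Usage Sketch
+
+A minimal sketch of the Rust API, condensed from the example above (CLI parsing and annotation omitted):
+
+```rust
+use usls::{models::OWLv2, DataLoader, Options};
+
+fn run() -> anyhow::Result<()> {
+    // build the model with a set of open-vocabulary class names
+    let options = Options::owlv2_base_ensemble()
+        .with_class_names(&["person", "bus"])
+        .commit()?;
+    let mut model = OWLv2::new(options)?;
+
+    // load images and run detection
+    let xs = DataLoader::try_read_batch(&["./assets/bus.jpg"])?;
+    let _ys = model.forward(&xs)?;
+    Ok(())
+}
+```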
diff --git a/src/models/owl/config.rs b/src/models/owl/config.rs
new file mode 100644
index 0000000..cc17da9
--- /dev/null
+++ b/src/models/owl/config.rs
@@ -0,0 +1,37 @@
+/// Model configuration for `OWLv2`
+impl crate::Options {
+ pub fn owlv2() -> Self {
+ Self::default()
+ .with_model_name("owlv2")
+ .with_model_kind(crate::Kind::VisionLanguage)
+ // 1st & 3rd: text
+ .with_model_ixx(0, 0, (1, 1, 1).into()) // TODO
+ .with_model_ixx(0, 1, 1.into())
+ .with_model_ixx(2, 0, (1, 1, 1).into())
+ .with_model_ixx(2, 1, 1.into())
+ .with_model_max_length(16)
+ // 2nd: image
+ .with_model_ixx(1, 0, (1, 1, 1).into())
+ .with_model_ixx(1, 1, 3.into())
+ .with_model_ixx(1, 2, 960.into())
+ .with_model_ixx(1, 3, 960.into())
+ .with_image_mean(&[0.48145466, 0.4578275, 0.40821073])
+ .with_image_std(&[0.26862954, 0.261_302_6, 0.275_777_1])
+ .with_resize_mode(crate::ResizeMode::FitAdaptive)
+ .with_normalize(true)
+ .with_class_confs(&[0.1])
+ .with_model_num_dry_run(0)
+ }
+
+ pub fn owlv2_base() -> Self {
+ Self::owlv2().with_model_file("base-patch16.onnx")
+ }
+
+ pub fn owlv2_base_ensemble() -> Self {
+ Self::owlv2().with_model_file("base-patch16-ensemble.onnx")
+ }
+
+ pub fn owlv2_base_ft() -> Self {
+ Self::owlv2().with_model_file("base-patch16-ft.onnx")
+ }
+}
diff --git a/src/models/owl/impl.rs b/src/models/owl/impl.rs
new file mode 100644
index 0000000..5125892
--- /dev/null
+++ b/src/models/owl/impl.rs
@@ -0,0 +1,156 @@
+use aksr::Builder;
+use anyhow::Result;
+use image::DynamicImage;
+use ndarray::{s, Axis};
+use rayon::prelude::*;
+
+use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y};
+
+#[derive(Debug, Builder)]
+pub struct OWLv2 {
+ engine: Engine,
+ height: usize,
+ width: usize,
+ batch: usize,
+ names: Vec<String>,
+ names_with_prompt: Vec<String>,
+ confs: DynConf,
+ ts: Ts,
+ processor: Processor,
+ spec: String,
+ input_ids: X,
+ attention_mask: X,
+}
+
+impl OWLv2 {
+ pub fn new(options: Options) -> Result<Self> {
+ let engine = options.to_engine()?;
+ let (batch, height, width, ts) = (
+ engine.batch().opt(),
+ engine.try_height().unwrap_or(&960.into()).opt(),
+ engine.try_width().unwrap_or(&960.into()).opt(),
+ engine.ts.clone(),
+ );
+ let spec = engine.spec().to_owned();
+ let processor = options
+ .to_processor()?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
+ let names: Vec<String> = options
+ .class_names()
+ .expect("No class names specified.")
+ .iter()
+ .map(|x| x.to_string())
+ .collect();
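+ // wrap each label in the "a photo of ..." prompt template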
+ let names_with_prompt: Vec<String> =
+ names.iter().map(|x| format!("a photo of {}", x)).collect();
+ let n = names.len();
+ let confs = DynConf::new(options.class_confs(), n);
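+ // tokenize the prompts, then reshape the flattened ids into an (n_classes, seq_len) matrix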
+ let input_ids: Vec<f32> = processor
+ .encode_texts_ids(
+ &names_with_prompt
+ .iter()
+ .map(|x| x.as_str())
+ .collect::<Vec<_>>(),
+ false,
+ )?
+ .into_iter()
+ .flatten()
+ .collect();
+ let input_ids: X = ndarray::Array2::from_shape_vec((n, input_ids.len() / n), input_ids)?
+ .into_dyn()
+ .into();
+ let attention_mask = X::ones_like(&input_ids);
+
+ Ok(Self {
+ engine,
+ height,
+ width,
+ batch,
+ spec,
+ names,
+ names_with_prompt,
+ confs,
+ ts,
+ processor,
+ input_ids,
+ attention_mask,
+ })
+ }
+
+ fn preprocess(&mut self, xs: &[DynamicImage]) -> Result<Xs> {
+ let image_embeddings = self.processor.process_images(xs)?;
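+ // assemble the inputs in graph order: input_ids (0), image tensor (1), attention_mask (2)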
+ let xs = Xs::from(vec![
+ self.input_ids.clone(),
+ image_embeddings,
+ self.attention_mask.clone(),
+ ]);
+
+ Ok(xs)
+ }
+
+ fn inference(&mut self, xs: Xs) -> Result<Xs> {
+ self.engine.run(xs)
+ }
+
+ pub fn forward(&mut self, xs: &[DynamicImage]) -> Result<Ys> {
+ let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? });
+ let ys = elapsed!("inference", self.ts, { self.inference(ys)? });
+ let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? });
+
+ Ok(ys)
+ }
+
+ fn postprocess(&mut self, xs: Xs) -> Result<Ys> {
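+ // xs[0]: per-query class logits; xs[1]: per-query boxes in normalized (cx, cy, w, h)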
+ let ys: Vec<Y> = xs[0]
+ .axis_iter(Axis(0))
+ .into_par_iter()
+ .zip(xs[1].axis_iter(Axis(0)).into_par_iter())
+ .enumerate()
+ .filter_map(|(idx, (clss, bboxes))| {
+ let (image_height, image_width) = self.processor.image0s_size[idx];
+ let ratio = image_height.max(image_width) as f32;
+ let y_bboxes: Vec<Bbox> = clss
+ .axis_iter(Axis(0))
+ .into_par_iter()
+ .enumerate()
+ .filter_map(|(i, clss_)| {
+ let (class_id, &confidence) = clss_
+ .into_iter()
+ .enumerate()
+ .max_by(|a, b| a.1.total_cmp(b.1))?;
+
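+ // sigmoid: map the raw class logit to a (0, 1) confidence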
+ let confidence = 1. / ((-confidence).exp() + 1.);
+ if confidence < self.confs[class_id] {
+ return None;
+ }
+
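+ // scale the normalized box by the longer original-image side, then
+ // convert center-form (cx, cy, w, h) to top-left (x, y, w, h)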
+ let bbox = bboxes.slice(s![i, ..]).mapv(|x| x * ratio);
+ let (x, y, w, h) = (
+ (bbox[0] - bbox[2] / 2.).max(0.0f32),
+ (bbox[1] - bbox[3] / 2.).max(0.0f32),
+ bbox[2],
+ bbox[3],
+ );
+
+ Some(
+ Bbox::default()
+ .with_xywh(x, y, w, h)
+ .with_confidence(confidence)
+ .with_id(class_id as isize)
+ .with_name(&self.names[class_id]),
+ )
+ })
+ .collect();
+
+ Some(Y::default().with_bboxes(&y_bboxes))
+ })
+ .collect();
+
+ Ok(ys.into())
+ }
+
+ pub fn summary(&mut self) {
+ self.ts.summary();
+ }
+}
diff --git a/src/models/owl/mod.rs b/src/models/owl/mod.rs
new file mode 100644
index 0000000..fbd2b75
--- /dev/null
+++ b/src/models/owl/mod.rs
@@ -0,0 +1,4 @@
+mod config;
+mod r#impl;
+
+pub use r#impl::*;
diff --git a/src/models/rtdetr/impl.rs b/src/models/rtdetr/impl.rs
index 70e5262..d33f249 100644
--- a/src/models/rtdetr/impl.rs
+++ b/src/models/rtdetr/impl.rs
@@ -1,11 +1,10 @@
-use crate::{elapsed, Bbox, DynConf, Engine, Processor, Ts, Xs, Ys, X, Y};
use aksr::Builder;
use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Axis};
use rayon::prelude::*;

-use crate::Options;
+use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y};

#[derive(Debug, Builder)]
pub struct RTDETR {
diff --git a/src/xy/x.rs b/src/xy/x.rs
index 17c07ec..a176585 100644
--- a/src/xy/x.rs
+++ b/src/xy/x.rs
@@ -75,11 +75,11 @@ impl X {
Self::from(Array::ones(Dim(IxDynImpl::from(shape.to_vec()))))
}

- pub fn zeros_like(x: Self) -> Self {
+ pub fn zeros_like(x: &Self) -> Self {
Self::from(Array::zeros(x.raw_dim()))
}

- pub fn ones_like(x: Self) -> Self {
+ pub fn ones_like(x: &Self) -> Self {
Self::from(Array::ones(x.raw_dim()))
}