diff --git a/README.md b/README.md index 85a2517..d78eb5c 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,8 @@ | [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ | ✅ | | [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) | ✅ | ✅ | ✅ | | | | [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | ✅ | | | -| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Detection
Open-Set Keypoints Detection
Image Caption
Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | | +| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection
Open-Set Keypoints Detection
Image Caption
Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | | +| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | | diff --git a/examples/owlv2/README.md b/examples/owlv2/README.md new file mode 100644 index 0000000..a6ee30e --- /dev/null +++ b/examples/owlv2/README.md @@ -0,0 +1,5 @@ +## Quick Start + +```shell +cargo run -r -F cuda --example owlv2 -- --device cuda:0 --dtype fp16 +``` diff --git a/examples/owlv2/main.rs b/examples/owlv2/main.rs new file mode 100644 index 0000000..577ae8a --- /dev/null +++ b/examples/owlv2/main.rs @@ -0,0 +1,70 @@ +use anyhow::Result; +use usls::{models::OWLv2, Annotator, DataLoader, Options}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// dtype + #[argh(option, default = "String::from(\"auto\")")] + dtype: String, + + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// source image + #[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")] + source: Vec, + + /// open class names + #[argh( + option, + default = "vec![ + String::from(\"person\"), + String::from(\"hand\"), + String::from(\"shoes\"), + String::from(\"bus\"), + String::from(\"car\"), + String::from(\"dog\"), + String::from(\"cat\"), + String::from(\"sign\"), + String::from(\"tie\"), + String::from(\"monitor\"), + String::from(\"glasses\"), + String::from(\"tree\"), + String::from(\"head\"), + ]" + )] + labels: Vec, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + let args: Args = argh::from_env(); + + // options + let options = Options::owlv2_base_ensemble() + // owlv2_base() + .with_model_dtype(args.dtype.as_str().try_into()?) + .with_model_device(args.device.as_str().try_into()?) + .with_class_names(&args.labels.iter().map(|x| x.as_str()).collect::>()) + .commit()?; + let mut model = OWLv2::new(options)?; + + // load + let xs = DataLoader::try_read_batch(&args.source)?; + + // run + let ys = model.forward(&xs)?; + + // annotate + let annotator = Annotator::default() + .with_bboxes_thickness(3) + .with_saveout(model.spec()); + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/src/models/mod.rs b/src/models/mod.rs index 9b5dc02..b80d974 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -17,6 +17,7 @@ mod linknet; mod mobileone; mod modnet; mod moondream2; +mod owl; mod picodet; mod pipeline; mod rtdetr; @@ -39,6 +40,7 @@ pub use florence2::*; pub use grounding_dino::*; pub use modnet::*; pub use moondream2::*; +pub use owl::*; pub use picodet::*; pub use pipeline::*; pub use rtdetr::*; diff --git a/src/models/owl/README.md b/src/models/owl/README.md new file mode 100644 index 0000000..e750de0 --- /dev/null +++ b/src/models/owl/README.md @@ -0,0 +1,9 @@ +# OWLv2: Scaling Open-Vocabulary Object Detection + +## Official Repository + +The official repository can be found on: [Hugging Face](https://huggingface.co/google/owlv2-base-patch16-ensemble) + +## Example + +Refer to the [example](../../../examples/owlv2) diff --git a/src/models/owl/config.rs b/src/models/owl/config.rs new file mode 100644 index 0000000..cc17da9 --- /dev/null +++ b/src/models/owl/config.rs @@ -0,0 +1,37 @@ +/// Model configuration for `OWLv2` +impl crate::Options { + pub fn owlv2() -> Self { + Self::default() + .with_model_name("owlv2") + .with_model_kind(crate::Kind::VisionLanguage) + // 1st & 3rd: text + .with_model_ixx(0, 0, (1, 1, 1).into()) // TODO + .with_model_ixx(0, 1, 1.into()) + .with_model_ixx(2, 0, (1, 1, 1).into()) + .with_model_ixx(2, 1, 1.into()) + .with_model_max_length(16) + // 2nd: image + .with_model_ixx(1, 0, (1, 1, 1).into()) + .with_model_ixx(1, 1, 3.into()) + .with_model_ixx(1, 2, 960.into()) + .with_model_ixx(1, 3, 960.into()) + .with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) + .with_image_std(&[0.26862954, 0.261_302_6, 0.275_777_1]) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_normalize(true) + .with_class_confs(&[0.1]) + .with_model_num_dry_run(0) + } + + pub fn owlv2_base() -> Self { + Self::owlv2().with_model_file("base-patch16.onnx") + } + + pub fn owlv2_base_ensemble() -> Self { + Self::owlv2().with_model_file("base-patch16-ensemble.onnx") + } + + pub fn owlv2_base_ft() -> Self { + Self::owlv2().with_model_file("base-patch16-ft.onnx") + } +} diff --git a/src/models/owl/impl.rs b/src/models/owl/impl.rs new file mode 100644 index 0000000..5125892 --- /dev/null +++ b/src/models/owl/impl.rs @@ -0,0 +1,156 @@ +use aksr::Builder; +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Axis}; +use rayon::prelude::*; + +use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y}; + +#[derive(Debug, Builder)] +pub struct OWLv2 { + engine: Engine, + height: usize, + width: usize, + batch: usize, + names: Vec, + names_with_prompt: Vec, + confs: DynConf, + ts: Ts, + processor: Processor, + spec: String, + input_ids: X, + attention_mask: X, +} + +impl OWLv2 { + pub fn new(options: Options) -> Result { + let engine = options.to_engine()?; + let (batch, height, width, ts) = ( + engine.batch().opt(), + engine.try_height().unwrap_or(&960.into()).opt(), + engine.try_width().unwrap_or(&960.into()).opt(), + engine.ts.clone(), + ); + let spec = engine.spec().to_owned(); + let processor = options + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let names: Vec = options + .class_names() + .expect("No class names specified.") + .iter() + .map(|x| x.to_string()) + .collect(); + let names_with_prompt: Vec = + names.iter().map(|x| format!("a photo of {}", x)).collect(); + let n = names.len(); + let confs = DynConf::new(options.class_confs(), n); + let input_ids: Vec = processor + .encode_texts_ids( + &names_with_prompt + .iter() + .map(|x| x.as_str()) + .collect::>(), + false, + )? + .into_iter() + .flatten() + .collect(); + let input_ids: X = ndarray::Array2::from_shape_vec((n, input_ids.len() / n), input_ids)? + .into_dyn() + .into(); + let attention_mask = X::ones_like(&input_ids); + + Ok(Self { + engine, + height, + width, + batch, + spec, + names, + names_with_prompt, + confs, + ts, + processor, + input_ids, + attention_mask, + }) + } + + fn preprocess(&mut self, xs: &[DynamicImage]) -> Result { + let image_embeddings = self.processor.process_images(xs)?; + let xs = Xs::from(vec![ + self.input_ids.clone(), + image_embeddings, + self.attention_mask.clone(), + ]); + + Ok(xs) + } + + fn inference(&mut self, xs: Xs) -> Result { + self.engine.run(xs) + } + + pub fn forward(&mut self, xs: &[DynamicImage]) -> Result { + let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); + let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); + let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); + + Ok(ys) + } + + fn postprocess(&mut self, xs: Xs) -> Result { + let ys: Vec = xs[0] + .axis_iter(Axis(0)) + .into_par_iter() + .zip(xs[1].axis_iter(Axis(0)).into_par_iter()) + .enumerate() + .filter_map(|(idx, (clss, bboxes))| { + let (image_height, image_width) = self.processor.image0s_size[idx]; + let ratio = image_height.max(image_width) as f32; + let y_bboxes: Vec = clss + .axis_iter(Axis(0)) + .into_par_iter() + .enumerate() + .filter_map(|(i, clss_)| { + let (class_id, &confidence) = clss_ + .into_iter() + .enumerate() + .max_by(|a, b| a.1.total_cmp(b.1))?; + + let confidence = 1. / ((-confidence).exp() + 1.); + if confidence < self.confs[class_id] { + return None; + } + + let bbox = bboxes.slice(s![i, ..]).mapv(|x| x * ratio); + let (x, y, w, h) = ( + (bbox[0] - bbox[2] / 2.).max(0.0f32), + (bbox[1] - bbox[3] / 2.).max(0.0f32), + bbox[2], + bbox[3], + ); + + Some( + Bbox::default() + .with_xywh(x, y, w, h) + .with_confidence(confidence) + .with_id(class_id as isize) + .with_name(&self.names[class_id]), + ) + }) + .collect(); + + Some(Y::default().with_bboxes(&y_bboxes)) + }) + .collect(); + + Ok(ys.into()) + } + + pub fn summary(&mut self) { + self.ts.summary(); + } +} diff --git a/src/models/owl/mod.rs b/src/models/owl/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/owl/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/models/rtdetr/impl.rs b/src/models/rtdetr/impl.rs index 70e5262..d33f249 100644 --- a/src/models/rtdetr/impl.rs +++ b/src/models/rtdetr/impl.rs @@ -1,11 +1,10 @@ -use crate::{elapsed, Bbox, DynConf, Engine, Processor, Ts, Xs, Ys, X, Y}; use aksr::Builder; use anyhow::Result; use image::DynamicImage; use ndarray::{s, Axis}; use rayon::prelude::*; -use crate::Options; +use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y}; #[derive(Debug, Builder)] pub struct RTDETR { diff --git a/src/xy/x.rs b/src/xy/x.rs index 17c07ec..a176585 100644 --- a/src/xy/x.rs +++ b/src/xy/x.rs @@ -75,11 +75,11 @@ impl X { Self::from(Array::ones(Dim(IxDynImpl::from(shape.to_vec())))) } - pub fn zeros_like(x: Self) -> Self { + pub fn zeros_like(x: &Self) -> Self { Self::from(Array::zeros(x.raw_dim())) } - pub fn ones_like(x: Self) -> Self { + pub fn ones_like(x: &Self) -> Self { Self::from(Array::ones(x.raw_dim())) }