diff --git a/examples/grounding-dino/main.rs b/examples/grounding-dino/main.rs index 78c6493..389ad89 100644 --- a/examples/grounding-dino/main.rs +++ b/examples/grounding-dino/main.rs @@ -20,18 +20,18 @@ struct Args { #[argh( option, default = "vec![ - String::from(\"person\"), - String::from(\"hand\"), - String::from(\"shoes\"), - String::from(\"bus\"), - String::from(\"dog\"), - String::from(\"cat\"), - String::from(\"sign\"), - String::from(\"tie\"), - String::from(\"monitor\"), - String::from(\"glasses\"), - String::from(\"tree\"), - String::from(\"head\"), + String::from(\"person\"), + String::from(\"a hand\"), + String::from(\"a shoe\"), + String::from(\"bus\"), + String::from(\"dog\"), + String::from(\"cat\"), + String::from(\"sign\"), + String::from(\"tie\"), + String::from(\"monitor\"), + String::from(\"glasses\"), + String::from(\"tree\"), + String::from(\"head\"), ]" )] labels: Vec, @@ -49,6 +49,8 @@ fn main() -> Result<()> { .with_model_dtype(args.dtype.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?) .with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::>()) + .with_class_confs(&[0.25]) + .with_text_confs(&[0.25]) .commit()?; let mut model = GroundingDINO::new(options)?; diff --git a/src/models/grounding_dino/config.rs b/src/models/grounding_dino/config.rs index 4c54ee0..0ec4f00 100644 --- a/src/models/grounding_dino/config.rs +++ b/src/models/grounding_dino/config.rs @@ -12,8 +12,8 @@ impl crate::Options { .with_image_mean(&[0.485, 0.456, 0.406]) .with_image_std(&[0.229, 0.224, 0.225]) .with_normalize(true) - .with_class_confs(&[0.4]) - .with_text_confs(&[0.3]) + .with_class_confs(&[0.25]) + .with_text_confs(&[0.25]) } pub fn grounding_dino_tiny() -> Self { diff --git a/src/models/grounding_dino/impl.rs b/src/models/grounding_dino/impl.rs index 46e8fc8..b04026c 100644 --- a/src/models/grounding_dino/impl.rs +++ b/src/models/grounding_dino/impl.rs @@ -1,8 +1,9 @@ use aksr::Builder; use anyhow::Result; use image::DynamicImage; -use ndarray::{s, Array, Axis}; +use ndarray::{s, Array2, Axis}; use rayon::prelude::*; +use std::fmt::Write; use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y}; @@ -15,6 +16,7 @@ pub struct GroundingDINO { confs_visual: DynConf, confs_textual: DynConf, class_names: Vec, + class_ids_map: Vec>, tokens: Vec, token_ids: Vec, ts: Ts, @@ -26,7 +28,6 @@ impl GroundingDINO { pub fn new(options: Options) -> Result { let engine = options.to_engine()?; let spec = engine.spec().to_string(); - let (batch, height, width, ts) = ( engine.batch().opt(), engine.try_height().unwrap_or(&800.into()).opt(), @@ -37,20 +38,27 @@ impl GroundingDINO { .to_processor()? .with_image_width(width as _) .with_image_height(height as _); - let confs_visual = DynConf::new(options.class_confs(), 1); - let confs_textual = DynConf::new(options.text_confs(), 1); - - let class_names = Self::parse_texts( - &options - .text_names - .expect("No class names specified!") - .iter() - .map(|x| x.as_str()) - .collect::>(), - ); - let token_ids = processor.encode_text_ids(&class_names, true)?; - let tokens = processor.encode_text_tokens(&class_names, true)?; - let class_names = tokens.clone(); + let class_names = options + .text_names + .as_ref() + .and_then(|v| { + let v: Vec<_> = v + .iter() + .map(|s| s.trim().to_ascii_lowercase()) + .filter(|s| !s.is_empty()) + .collect(); + (!v.is_empty()).then_some(v) + }) + .ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the options. Ensure the 'text_names' field is non-empty and contains valid class names."))?; + let text_prompt = class_names.iter().fold(String::new(), |mut acc, text| { + write!(&mut acc, "{}.", text).unwrap(); + acc + }); + let token_ids = processor.encode_text_ids(&text_prompt, true)?; + let tokens = processor.encode_text_tokens(&text_prompt, true)?; + let class_ids_map = Self::process_class_ids(&tokens); + let confs_visual = DynConf::new(options.class_confs(), class_names.len()); + let confs_textual = DynConf::new(options.text_confs(), class_names.len()); Ok(Self { engine, @@ -65,6 +73,7 @@ impl GroundingDINO { ts, processor, spec, + class_ids_map, }) } @@ -73,30 +82,17 @@ impl GroundingDINO { let image_embeddings = self.processor.process_images(xs)?; // encode texts - let tokens_f32 = self - .tokens - .iter() - .map(|x| if x == "." { 1. } else { 0. }) - .collect::>(); - - // input_ids let input_ids = X::from(self.token_ids.clone()) .insert_axis(0)? .repeat(0, self.batch)?; - - // token_type_ids - let token_type_ids = X::zeros(&[self.batch, tokens_f32.len()]); - - // attention_mask - let attention_mask = X::ones(&[self.batch, tokens_f32.len()]); - - // text_self_attention_masks - let text_self_attention_masks = Self::gen_text_self_attention_masks(&tokens_f32)? + let token_type_ids = X::zeros(&[self.batch, self.tokens.len()]); + let attention_mask = X::ones(&[self.batch, self.tokens.len()]); + let (text_self_attention_masks, position_ids) = + Self::gen_text_attn_masks_and_pos_ids(&self.token_ids)?; + let text_self_attention_masks = text_self_attention_masks .insert_axis(0)? .repeat(0, self.batch)?; - - // position_ids - let position_ids = X::from(tokens_f32).insert_axis(0)?.repeat(0, self.batch)?; + let position_ids = position_ids.insert_axis(0)?.repeat(0, self.batch)?; // inputs let xs = Xs::from(vec![ @@ -135,7 +131,6 @@ impl GroundingDINO { .filter_map(|(idx, logits)| { let (image_height, image_width) = self.processor.image0s_size[idx]; let ratio = self.processor.scale_factors_hw[idx][0]; - let y_bboxes: Vec = logits .axis_iter(Axis(0)) .into_par_iter() @@ -161,13 +156,15 @@ impl GroundingDINO { let x = x.max(0.0).min(image_width as _); let y = y.max(0.0).min(image_height as _); - Some( - Bbox::default() - .with_xywh(x, y, w, h) - .with_id(class_id as _) - .with_name(&self.class_names[class_id]) - .with_confidence(conf), - ) + self.class_ids_map[class_id].map(|c| { + let mut bbox = + Bbox::default().with_xywh(x, y, w, h).with_confidence(conf); + + if conf > self.confs_textual[c] { + bbox = bbox.with_name(&self.class_names[c]).with_id(c as _); + } + bbox + }) }) .collect(); @@ -182,42 +179,69 @@ impl GroundingDINO { Ok(ys.into()) } - fn parse_texts(texts: &[&str]) -> String { - let mut y = String::new(); - for text in texts.iter() { - if !text.is_empty() { - y.push_str(&format!("{} . ", text)); + fn gen_text_attn_masks_and_pos_ids(input_ids: &[f32]) -> Result<(X, X)> { + let n = input_ids.len(); + let mut vs: Vec = input_ids + .iter() + .map(|&x| { + if (x - 101.0).abs() < f32::EPSILON + || (x - 1012.0).abs() < f32::EPSILON + || (x - 102.0).abs() < f32::EPSILON + { + 1.0 + } else { + 0.0 + } + }) + .collect(); + vs[0] = 1.0; + vs[n - 1] = 1.0; + + let special_idxs: Vec = vs + .iter() + .enumerate() + .filter_map(|(i, &v)| { + if (v - 1.0).abs() < f32::EPSILON { + Some(i) + } else { + None + } + }) + .collect(); + let mut attn = Array2::::eye(n); + let mut pos_ids = vec![0f32; n]; + let mut prev = special_idxs[0]; + for &idx in special_idxs.iter() { + if idx == 0 || idx == n - 1 { + } else { + for r in (prev + 1)..=idx { + for c in (prev + 1)..=idx { + attn[[r, c]] = 1.0; + } + } + for (offset, pos_id) in pos_ids[prev + 1..=idx].iter_mut().enumerate() { + *pos_id = offset as f32; + } } + prev = idx; } - y + + Ok((X::from(attn.into_dyn()), X::from(pos_ids))) } - fn gen_text_self_attention_masks(tokens: &[f32]) -> Result { - let mut vs = tokens.to_vec(); - let n = vs.len(); - vs[0] = 1.; - vs[n - 1] = 1.; - let mut ys = Array::zeros((n, n)).into_dyn(); - let mut i_last = -1; - for (i, &v) in vs.iter().enumerate() { - if v == 0. { - if i_last == -1 { - i_last = i as isize; - } else { - i_last = -1; - } - } else if v == 1. { - if i_last == -1 { - ys.slice_mut(s![i, i]).fill(1.); - } else { - ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1]) - .fill(1.); - } - i_last = -1; + fn process_class_ids(tokens: &[String]) -> Vec> { + let mut result = Vec::with_capacity(tokens.len()); + let mut idx = 0; + for token in tokens { + if token == "[CLS]" || token == "[SEP]" { + result.push(None); + } else if token == "." { + result.push(None); + idx += 1; } else { - continue; + result.push(Some(idx)); } } - Ok(X::from(ys)) + result } }