Fix errors related to GroundingDino class_names (#71)

This commit is contained in:
Jamjamjon
2025-03-29 15:34:18 +08:00
committed by GitHub
parent 293c7d2e38
commit 118690402d
3 changed files with 113 additions and 87 deletions

View File

@ -20,18 +20,18 @@ struct Args {
#[argh( #[argh(
option, option,
default = "vec![ default = "vec![
String::from(\"person\"), String::from(\"person\"),
String::from(\"hand\"), String::from(\"a hand\"),
String::from(\"shoes\"), String::from(\"a shoe\"),
String::from(\"bus\"), String::from(\"bus\"),
String::from(\"dog\"), String::from(\"dog\"),
String::from(\"cat\"), String::from(\"cat\"),
String::from(\"sign\"), String::from(\"sign\"),
String::from(\"tie\"), String::from(\"tie\"),
String::from(\"monitor\"), String::from(\"monitor\"),
String::from(\"glasses\"), String::from(\"glasses\"),
String::from(\"tree\"), String::from(\"tree\"),
String::from(\"head\"), String::from(\"head\"),
]" ]"
)] )]
labels: Vec<String>, labels: Vec<String>,
@ -49,6 +49,8 @@ fn main() -> Result<()> {
.with_model_dtype(args.dtype.as_str().try_into()?) .with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?)
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>()) .with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
.with_class_confs(&[0.25])
.with_text_confs(&[0.25])
.commit()?; .commit()?;
let mut model = GroundingDINO::new(options)?; let mut model = GroundingDINO::new(options)?;

View File

@ -12,8 +12,8 @@ impl crate::Options {
.with_image_mean(&[0.485, 0.456, 0.406]) .with_image_mean(&[0.485, 0.456, 0.406])
.with_image_std(&[0.229, 0.224, 0.225]) .with_image_std(&[0.229, 0.224, 0.225])
.with_normalize(true) .with_normalize(true)
.with_class_confs(&[0.4]) .with_class_confs(&[0.25])
.with_text_confs(&[0.3]) .with_text_confs(&[0.25])
} }
pub fn grounding_dino_tiny() -> Self { pub fn grounding_dino_tiny() -> Self {

View File

@ -1,8 +1,9 @@
use aksr::Builder; use aksr::Builder;
use anyhow::Result; use anyhow::Result;
use image::DynamicImage; use image::DynamicImage;
use ndarray::{s, Array, Axis}; use ndarray::{s, Array2, Axis};
use rayon::prelude::*; use rayon::prelude::*;
use std::fmt::Write;
use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y}; use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y};
@ -15,6 +16,7 @@ pub struct GroundingDINO {
confs_visual: DynConf, confs_visual: DynConf,
confs_textual: DynConf, confs_textual: DynConf,
class_names: Vec<String>, class_names: Vec<String>,
class_ids_map: Vec<Option<usize>>,
tokens: Vec<String>, tokens: Vec<String>,
token_ids: Vec<f32>, token_ids: Vec<f32>,
ts: Ts, ts: Ts,
@ -26,7 +28,6 @@ impl GroundingDINO {
pub fn new(options: Options) -> Result<Self> { pub fn new(options: Options) -> Result<Self> {
let engine = options.to_engine()?; let engine = options.to_engine()?;
let spec = engine.spec().to_string(); let spec = engine.spec().to_string();
let (batch, height, width, ts) = ( let (batch, height, width, ts) = (
engine.batch().opt(), engine.batch().opt(),
engine.try_height().unwrap_or(&800.into()).opt(), engine.try_height().unwrap_or(&800.into()).opt(),
@ -37,20 +38,27 @@ impl GroundingDINO {
.to_processor()? .to_processor()?
.with_image_width(width as _) .with_image_width(width as _)
.with_image_height(height as _); .with_image_height(height as _);
let confs_visual = DynConf::new(options.class_confs(), 1); let class_names = options
let confs_textual = DynConf::new(options.text_confs(), 1); .text_names
.as_ref()
let class_names = Self::parse_texts( .and_then(|v| {
&options let v: Vec<_> = v
.text_names .iter()
.expect("No class names specified!") .map(|s| s.trim().to_ascii_lowercase())
.iter() .filter(|s| !s.is_empty())
.map(|x| x.as_str()) .collect();
.collect::<Vec<_>>(), (!v.is_empty()).then_some(v)
); })
let token_ids = processor.encode_text_ids(&class_names, true)?; .ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the options. Ensure the 'text_names' field is non-empty and contains valid class names."))?;
let tokens = processor.encode_text_tokens(&class_names, true)?; let text_prompt = class_names.iter().fold(String::new(), |mut acc, text| {
let class_names = tokens.clone(); write!(&mut acc, "{}.", text).unwrap();
acc
});
let token_ids = processor.encode_text_ids(&text_prompt, true)?;
let tokens = processor.encode_text_tokens(&text_prompt, true)?;
let class_ids_map = Self::process_class_ids(&tokens);
let confs_visual = DynConf::new(options.class_confs(), class_names.len());
let confs_textual = DynConf::new(options.text_confs(), class_names.len());
Ok(Self { Ok(Self {
engine, engine,
@ -65,6 +73,7 @@ impl GroundingDINO {
ts, ts,
processor, processor,
spec, spec,
class_ids_map,
}) })
} }
@ -73,30 +82,17 @@ impl GroundingDINO {
let image_embeddings = self.processor.process_images(xs)?; let image_embeddings = self.processor.process_images(xs)?;
// encode texts // encode texts
let tokens_f32 = self
.tokens
.iter()
.map(|x| if x == "." { 1. } else { 0. })
.collect::<Vec<_>>();
// input_ids
let input_ids = X::from(self.token_ids.clone()) let input_ids = X::from(self.token_ids.clone())
.insert_axis(0)? .insert_axis(0)?
.repeat(0, self.batch)?; .repeat(0, self.batch)?;
let token_type_ids = X::zeros(&[self.batch, self.tokens.len()]);
// token_type_ids let attention_mask = X::ones(&[self.batch, self.tokens.len()]);
let token_type_ids = X::zeros(&[self.batch, tokens_f32.len()]); let (text_self_attention_masks, position_ids) =
Self::gen_text_attn_masks_and_pos_ids(&self.token_ids)?;
// attention_mask let text_self_attention_masks = text_self_attention_masks
let attention_mask = X::ones(&[self.batch, tokens_f32.len()]);
// text_self_attention_masks
let text_self_attention_masks = Self::gen_text_self_attention_masks(&tokens_f32)?
.insert_axis(0)? .insert_axis(0)?
.repeat(0, self.batch)?; .repeat(0, self.batch)?;
let position_ids = position_ids.insert_axis(0)?.repeat(0, self.batch)?;
// position_ids
let position_ids = X::from(tokens_f32).insert_axis(0)?.repeat(0, self.batch)?;
// inputs // inputs
let xs = Xs::from(vec![ let xs = Xs::from(vec![
@ -135,7 +131,6 @@ impl GroundingDINO {
.filter_map(|(idx, logits)| { .filter_map(|(idx, logits)| {
let (image_height, image_width) = self.processor.image0s_size[idx]; let (image_height, image_width) = self.processor.image0s_size[idx];
let ratio = self.processor.scale_factors_hw[idx][0]; let ratio = self.processor.scale_factors_hw[idx][0];
let y_bboxes: Vec<Bbox> = logits let y_bboxes: Vec<Bbox> = logits
.axis_iter(Axis(0)) .axis_iter(Axis(0))
.into_par_iter() .into_par_iter()
@ -161,13 +156,15 @@ impl GroundingDINO {
let x = x.max(0.0).min(image_width as _); let x = x.max(0.0).min(image_width as _);
let y = y.max(0.0).min(image_height as _); let y = y.max(0.0).min(image_height as _);
Some( self.class_ids_map[class_id].map(|c| {
Bbox::default() let mut bbox =
.with_xywh(x, y, w, h) Bbox::default().with_xywh(x, y, w, h).with_confidence(conf);
.with_id(class_id as _)
.with_name(&self.class_names[class_id]) if conf > self.confs_textual[c] {
.with_confidence(conf), bbox = bbox.with_name(&self.class_names[c]).with_id(c as _);
) }
bbox
})
}) })
.collect(); .collect();
@ -182,42 +179,69 @@ impl GroundingDINO {
Ok(ys.into()) Ok(ys.into())
} }
fn parse_texts(texts: &[&str]) -> String { fn gen_text_attn_masks_and_pos_ids(input_ids: &[f32]) -> Result<(X, X)> {
let mut y = String::new(); let n = input_ids.len();
for text in texts.iter() { let mut vs: Vec<f32> = input_ids
if !text.is_empty() { .iter()
y.push_str(&format!("{} . ", text)); .map(|&x| {
if (x - 101.0).abs() < f32::EPSILON
|| (x - 1012.0).abs() < f32::EPSILON
|| (x - 102.0).abs() < f32::EPSILON
{
1.0
} else {
0.0
}
})
.collect();
vs[0] = 1.0;
vs[n - 1] = 1.0;
let special_idxs: Vec<usize> = vs
.iter()
.enumerate()
.filter_map(|(i, &v)| {
if (v - 1.0).abs() < f32::EPSILON {
Some(i)
} else {
None
}
})
.collect();
let mut attn = Array2::<f32>::eye(n);
let mut pos_ids = vec![0f32; n];
let mut prev = special_idxs[0];
for &idx in special_idxs.iter() {
if idx == 0 || idx == n - 1 {
} else {
for r in (prev + 1)..=idx {
for c in (prev + 1)..=idx {
attn[[r, c]] = 1.0;
}
}
for (offset, pos_id) in pos_ids[prev + 1..=idx].iter_mut().enumerate() {
*pos_id = offset as f32;
}
} }
prev = idx;
} }
y
Ok((X::from(attn.into_dyn()), X::from(pos_ids)))
} }
fn gen_text_self_attention_masks(tokens: &[f32]) -> Result<X> { fn process_class_ids(tokens: &[String]) -> Vec<Option<usize>> {
let mut vs = tokens.to_vec(); let mut result = Vec::with_capacity(tokens.len());
let n = vs.len(); let mut idx = 0;
vs[0] = 1.; for token in tokens {
vs[n - 1] = 1.; if token == "[CLS]" || token == "[SEP]" {
let mut ys = Array::zeros((n, n)).into_dyn(); result.push(None);
let mut i_last = -1; } else if token == "." {
for (i, &v) in vs.iter().enumerate() { result.push(None);
if v == 0. { idx += 1;
if i_last == -1 {
i_last = i as isize;
} else {
i_last = -1;
}
} else if v == 1. {
if i_last == -1 {
ys.slice_mut(s![i, i]).fill(1.);
} else {
ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1])
.fill(1.);
}
i_last = -1;
} else { } else {
continue; result.push(Some(idx));
} }
} }
Ok(X::from(ys)) result
} }
} }