Fix errors related to GroundingDino class_names (#71)

This commit is contained in:
Jamjamjon
2025-03-29 15:34:18 +08:00
committed by GitHub
parent 293c7d2e38
commit 118690402d
3 changed files with 113 additions and 87 deletions

View File

@ -21,8 +21,8 @@ struct Args {
option,
default = "vec![
String::from(\"person\"),
String::from(\"hand\"),
String::from(\"shoes\"),
String::from(\"a hand\"),
String::from(\"a shoe\"),
String::from(\"bus\"),
String::from(\"dog\"),
String::from(\"cat\"),
@ -49,6 +49,8 @@ fn main() -> Result<()> {
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
.with_class_confs(&[0.25])
.with_text_confs(&[0.25])
.commit()?;
let mut model = GroundingDINO::new(options)?;

View File

@ -12,8 +12,8 @@ impl crate::Options {
.with_image_mean(&[0.485, 0.456, 0.406])
.with_image_std(&[0.229, 0.224, 0.225])
.with_normalize(true)
.with_class_confs(&[0.4])
.with_text_confs(&[0.3])
.with_class_confs(&[0.25])
.with_text_confs(&[0.25])
}
pub fn grounding_dino_tiny() -> Self {

View File

@ -1,8 +1,9 @@
use aksr::Builder;
use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Array, Axis};
use ndarray::{s, Array2, Axis};
use rayon::prelude::*;
use std::fmt::Write;
use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y};
@ -15,6 +16,7 @@ pub struct GroundingDINO {
confs_visual: DynConf,
confs_textual: DynConf,
class_names: Vec<String>,
class_ids_map: Vec<Option<usize>>,
tokens: Vec<String>,
token_ids: Vec<f32>,
ts: Ts,
@ -26,7 +28,6 @@ impl GroundingDINO {
pub fn new(options: Options) -> Result<Self> {
let engine = options.to_engine()?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&800.into()).opt(),
@ -37,20 +38,27 @@ impl GroundingDINO {
.to_processor()?
.with_image_width(width as _)
.with_image_height(height as _);
let confs_visual = DynConf::new(options.class_confs(), 1);
let confs_textual = DynConf::new(options.text_confs(), 1);
let class_names = Self::parse_texts(
&options
let class_names = options
.text_names
.expect("No class names specified!")
.as_ref()
.and_then(|v| {
let v: Vec<_> = v
.iter()
.map(|x| x.as_str())
.collect::<Vec<_>>(),
);
let token_ids = processor.encode_text_ids(&class_names, true)?;
let tokens = processor.encode_text_tokens(&class_names, true)?;
let class_names = tokens.clone();
.map(|s| s.trim().to_ascii_lowercase())
.filter(|s| !s.is_empty())
.collect();
(!v.is_empty()).then_some(v)
})
.ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the options. Ensure the 'text_names' field is non-empty and contains valid class names."))?;
let text_prompt = class_names.iter().fold(String::new(), |mut acc, text| {
write!(&mut acc, "{}.", text).unwrap();
acc
});
let token_ids = processor.encode_text_ids(&text_prompt, true)?;
let tokens = processor.encode_text_tokens(&text_prompt, true)?;
let class_ids_map = Self::process_class_ids(&tokens);
let confs_visual = DynConf::new(options.class_confs(), class_names.len());
let confs_textual = DynConf::new(options.text_confs(), class_names.len());
Ok(Self {
engine,
@ -65,6 +73,7 @@ impl GroundingDINO {
ts,
processor,
spec,
class_ids_map,
})
}
@ -73,30 +82,17 @@ impl GroundingDINO {
let image_embeddings = self.processor.process_images(xs)?;
// encode texts
let tokens_f32 = self
.tokens
.iter()
.map(|x| if x == "." { 1. } else { 0. })
.collect::<Vec<_>>();
// input_ids
let input_ids = X::from(self.token_ids.clone())
.insert_axis(0)?
.repeat(0, self.batch)?;
// token_type_ids
let token_type_ids = X::zeros(&[self.batch, tokens_f32.len()]);
// attention_mask
let attention_mask = X::ones(&[self.batch, tokens_f32.len()]);
// text_self_attention_masks
let text_self_attention_masks = Self::gen_text_self_attention_masks(&tokens_f32)?
let token_type_ids = X::zeros(&[self.batch, self.tokens.len()]);
let attention_mask = X::ones(&[self.batch, self.tokens.len()]);
let (text_self_attention_masks, position_ids) =
Self::gen_text_attn_masks_and_pos_ids(&self.token_ids)?;
let text_self_attention_masks = text_self_attention_masks
.insert_axis(0)?
.repeat(0, self.batch)?;
// position_ids
let position_ids = X::from(tokens_f32).insert_axis(0)?.repeat(0, self.batch)?;
let position_ids = position_ids.insert_axis(0)?.repeat(0, self.batch)?;
// inputs
let xs = Xs::from(vec![
@ -135,7 +131,6 @@ impl GroundingDINO {
.filter_map(|(idx, logits)| {
let (image_height, image_width) = self.processor.image0s_size[idx];
let ratio = self.processor.scale_factors_hw[idx][0];
let y_bboxes: Vec<Bbox> = logits
.axis_iter(Axis(0))
.into_par_iter()
@ -161,13 +156,15 @@ impl GroundingDINO {
let x = x.max(0.0).min(image_width as _);
let y = y.max(0.0).min(image_height as _);
Some(
Bbox::default()
.with_xywh(x, y, w, h)
.with_id(class_id as _)
.with_name(&self.class_names[class_id])
.with_confidence(conf),
)
self.class_ids_map[class_id].map(|c| {
let mut bbox =
Bbox::default().with_xywh(x, y, w, h).with_confidence(conf);
if conf > self.confs_textual[c] {
bbox = bbox.with_name(&self.class_names[c]).with_id(c as _);
}
bbox
})
})
.collect();
@ -182,42 +179,69 @@ impl GroundingDINO {
Ok(ys.into())
}
fn parse_texts(texts: &[&str]) -> String {
let mut y = String::new();
for text in texts.iter() {
if !text.is_empty() {
y.push_str(&format!("{} . ", text));
fn gen_text_attn_masks_and_pos_ids(input_ids: &[f32]) -> Result<(X, X)> {
let n = input_ids.len();
let mut vs: Vec<f32> = input_ids
.iter()
.map(|&x| {
if (x - 101.0).abs() < f32::EPSILON
|| (x - 1012.0).abs() < f32::EPSILON
|| (x - 102.0).abs() < f32::EPSILON
{
1.0
} else {
0.0
}
})
.collect();
vs[0] = 1.0;
vs[n - 1] = 1.0;
let special_idxs: Vec<usize> = vs
.iter()
.enumerate()
.filter_map(|(i, &v)| {
if (v - 1.0).abs() < f32::EPSILON {
Some(i)
} else {
None
}
})
.collect();
let mut attn = Array2::<f32>::eye(n);
let mut pos_ids = vec![0f32; n];
let mut prev = special_idxs[0];
for &idx in special_idxs.iter() {
if idx == 0 || idx == n - 1 {
} else {
for r in (prev + 1)..=idx {
for c in (prev + 1)..=idx {
attn[[r, c]] = 1.0;
}
}
y
for (offset, pos_id) in pos_ids[prev + 1..=idx].iter_mut().enumerate() {
*pos_id = offset as f32;
}
}
prev = idx;
}
fn gen_text_self_attention_masks(tokens: &[f32]) -> Result<X> {
let mut vs = tokens.to_vec();
let n = vs.len();
vs[0] = 1.;
vs[n - 1] = 1.;
let mut ys = Array::zeros((n, n)).into_dyn();
let mut i_last = -1;
for (i, &v) in vs.iter().enumerate() {
if v == 0. {
if i_last == -1 {
i_last = i as isize;
Ok((X::from(attn.into_dyn()), X::from(pos_ids)))
}
fn process_class_ids(tokens: &[String]) -> Vec<Option<usize>> {
let mut result = Vec::with_capacity(tokens.len());
let mut idx = 0;
for token in tokens {
if token == "[CLS]" || token == "[SEP]" {
result.push(None);
} else if token == "." {
result.push(None);
idx += 1;
} else {
i_last = -1;
}
} else if v == 1. {
if i_last == -1 {
ys.slice_mut(s![i, i]).fill(1.);
} else {
ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1])
.fill(1.);
}
i_last = -1;
} else {
continue;
result.push(Some(idx));
}
}
Ok(X::from(ys))
result
}
}