mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Fix errors related to GroundingDino class_names (#71)
This commit is contained in:
@ -21,8 +21,8 @@ struct Args {
|
||||
option,
|
||||
default = "vec![
|
||||
String::from(\"person\"),
|
||||
String::from(\"hand\"),
|
||||
String::from(\"shoes\"),
|
||||
String::from(\"a hand\"),
|
||||
String::from(\"a shoe\"),
|
||||
String::from(\"bus\"),
|
||||
String::from(\"dog\"),
|
||||
String::from(\"cat\"),
|
||||
@ -49,6 +49,8 @@ fn main() -> Result<()> {
|
||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
|
||||
.with_class_confs(&[0.25])
|
||||
.with_text_confs(&[0.25])
|
||||
.commit()?;
|
||||
|
||||
let mut model = GroundingDINO::new(options)?;
|
||||
|
@ -12,8 +12,8 @@ impl crate::Options {
|
||||
.with_image_mean(&[0.485, 0.456, 0.406])
|
||||
.with_image_std(&[0.229, 0.224, 0.225])
|
||||
.with_normalize(true)
|
||||
.with_class_confs(&[0.4])
|
||||
.with_text_confs(&[0.3])
|
||||
.with_class_confs(&[0.25])
|
||||
.with_text_confs(&[0.25])
|
||||
}
|
||||
|
||||
pub fn grounding_dino_tiny() -> Self {
|
||||
|
@ -1,8 +1,9 @@
|
||||
use aksr::Builder;
|
||||
use anyhow::Result;
|
||||
use image::DynamicImage;
|
||||
use ndarray::{s, Array, Axis};
|
||||
use ndarray::{s, Array2, Axis};
|
||||
use rayon::prelude::*;
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y};
|
||||
|
||||
@ -15,6 +16,7 @@ pub struct GroundingDINO {
|
||||
confs_visual: DynConf,
|
||||
confs_textual: DynConf,
|
||||
class_names: Vec<String>,
|
||||
class_ids_map: Vec<Option<usize>>,
|
||||
tokens: Vec<String>,
|
||||
token_ids: Vec<f32>,
|
||||
ts: Ts,
|
||||
@ -26,7 +28,6 @@ impl GroundingDINO {
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let engine = options.to_engine()?;
|
||||
let spec = engine.spec().to_string();
|
||||
|
||||
let (batch, height, width, ts) = (
|
||||
engine.batch().opt(),
|
||||
engine.try_height().unwrap_or(&800.into()).opt(),
|
||||
@ -37,20 +38,27 @@ impl GroundingDINO {
|
||||
.to_processor()?
|
||||
.with_image_width(width as _)
|
||||
.with_image_height(height as _);
|
||||
let confs_visual = DynConf::new(options.class_confs(), 1);
|
||||
let confs_textual = DynConf::new(options.text_confs(), 1);
|
||||
|
||||
let class_names = Self::parse_texts(
|
||||
&options
|
||||
.text_names
|
||||
.expect("No class names specified!")
|
||||
.iter()
|
||||
.map(|x| x.as_str())
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let token_ids = processor.encode_text_ids(&class_names, true)?;
|
||||
let tokens = processor.encode_text_tokens(&class_names, true)?;
|
||||
let class_names = tokens.clone();
|
||||
let class_names = options
|
||||
.text_names
|
||||
.as_ref()
|
||||
.and_then(|v| {
|
||||
let v: Vec<_> = v
|
||||
.iter()
|
||||
.map(|s| s.trim().to_ascii_lowercase())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
(!v.is_empty()).then_some(v)
|
||||
})
|
||||
.ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the options. Ensure the 'text_names' field is non-empty and contains valid class names."))?;
|
||||
let text_prompt = class_names.iter().fold(String::new(), |mut acc, text| {
|
||||
write!(&mut acc, "{}.", text).unwrap();
|
||||
acc
|
||||
});
|
||||
let token_ids = processor.encode_text_ids(&text_prompt, true)?;
|
||||
let tokens = processor.encode_text_tokens(&text_prompt, true)?;
|
||||
let class_ids_map = Self::process_class_ids(&tokens);
|
||||
let confs_visual = DynConf::new(options.class_confs(), class_names.len());
|
||||
let confs_textual = DynConf::new(options.text_confs(), class_names.len());
|
||||
|
||||
Ok(Self {
|
||||
engine,
|
||||
@ -65,6 +73,7 @@ impl GroundingDINO {
|
||||
ts,
|
||||
processor,
|
||||
spec,
|
||||
class_ids_map,
|
||||
})
|
||||
}
|
||||
|
||||
@ -73,30 +82,17 @@ impl GroundingDINO {
|
||||
let image_embeddings = self.processor.process_images(xs)?;
|
||||
|
||||
// encode texts
|
||||
let tokens_f32 = self
|
||||
.tokens
|
||||
.iter()
|
||||
.map(|x| if x == "." { 1. } else { 0. })
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// input_ids
|
||||
let input_ids = X::from(self.token_ids.clone())
|
||||
.insert_axis(0)?
|
||||
.repeat(0, self.batch)?;
|
||||
|
||||
// token_type_ids
|
||||
let token_type_ids = X::zeros(&[self.batch, tokens_f32.len()]);
|
||||
|
||||
// attention_mask
|
||||
let attention_mask = X::ones(&[self.batch, tokens_f32.len()]);
|
||||
|
||||
// text_self_attention_masks
|
||||
let text_self_attention_masks = Self::gen_text_self_attention_masks(&tokens_f32)?
|
||||
let token_type_ids = X::zeros(&[self.batch, self.tokens.len()]);
|
||||
let attention_mask = X::ones(&[self.batch, self.tokens.len()]);
|
||||
let (text_self_attention_masks, position_ids) =
|
||||
Self::gen_text_attn_masks_and_pos_ids(&self.token_ids)?;
|
||||
let text_self_attention_masks = text_self_attention_masks
|
||||
.insert_axis(0)?
|
||||
.repeat(0, self.batch)?;
|
||||
|
||||
// position_ids
|
||||
let position_ids = X::from(tokens_f32).insert_axis(0)?.repeat(0, self.batch)?;
|
||||
let position_ids = position_ids.insert_axis(0)?.repeat(0, self.batch)?;
|
||||
|
||||
// inputs
|
||||
let xs = Xs::from(vec![
|
||||
@ -135,7 +131,6 @@ impl GroundingDINO {
|
||||
.filter_map(|(idx, logits)| {
|
||||
let (image_height, image_width) = self.processor.image0s_size[idx];
|
||||
let ratio = self.processor.scale_factors_hw[idx][0];
|
||||
|
||||
let y_bboxes: Vec<Bbox> = logits
|
||||
.axis_iter(Axis(0))
|
||||
.into_par_iter()
|
||||
@ -161,13 +156,15 @@ impl GroundingDINO {
|
||||
let x = x.max(0.0).min(image_width as _);
|
||||
let y = y.max(0.0).min(image_height as _);
|
||||
|
||||
Some(
|
||||
Bbox::default()
|
||||
.with_xywh(x, y, w, h)
|
||||
.with_id(class_id as _)
|
||||
.with_name(&self.class_names[class_id])
|
||||
.with_confidence(conf),
|
||||
)
|
||||
self.class_ids_map[class_id].map(|c| {
|
||||
let mut bbox =
|
||||
Bbox::default().with_xywh(x, y, w, h).with_confidence(conf);
|
||||
|
||||
if conf > self.confs_textual[c] {
|
||||
bbox = bbox.with_name(&self.class_names[c]).with_id(c as _);
|
||||
}
|
||||
bbox
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
@ -182,42 +179,69 @@ impl GroundingDINO {
|
||||
Ok(ys.into())
|
||||
}
|
||||
|
||||
fn parse_texts(texts: &[&str]) -> String {
|
||||
let mut y = String::new();
|
||||
for text in texts.iter() {
|
||||
if !text.is_empty() {
|
||||
y.push_str(&format!("{} . ", text));
|
||||
fn gen_text_attn_masks_and_pos_ids(input_ids: &[f32]) -> Result<(X, X)> {
|
||||
let n = input_ids.len();
|
||||
let mut vs: Vec<f32> = input_ids
|
||||
.iter()
|
||||
.map(|&x| {
|
||||
if (x - 101.0).abs() < f32::EPSILON
|
||||
|| (x - 1012.0).abs() < f32::EPSILON
|
||||
|| (x - 102.0).abs() < f32::EPSILON
|
||||
{
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
vs[0] = 1.0;
|
||||
vs[n - 1] = 1.0;
|
||||
|
||||
let special_idxs: Vec<usize> = vs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, &v)| {
|
||||
if (v - 1.0).abs() < f32::EPSILON {
|
||||
Some(i)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let mut attn = Array2::<f32>::eye(n);
|
||||
let mut pos_ids = vec![0f32; n];
|
||||
let mut prev = special_idxs[0];
|
||||
for &idx in special_idxs.iter() {
|
||||
if idx == 0 || idx == n - 1 {
|
||||
} else {
|
||||
for r in (prev + 1)..=idx {
|
||||
for c in (prev + 1)..=idx {
|
||||
attn[[r, c]] = 1.0;
|
||||
}
|
||||
}
|
||||
for (offset, pos_id) in pos_ids[prev + 1..=idx].iter_mut().enumerate() {
|
||||
*pos_id = offset as f32;
|
||||
}
|
||||
}
|
||||
prev = idx;
|
||||
}
|
||||
y
|
||||
|
||||
Ok((X::from(attn.into_dyn()), X::from(pos_ids)))
|
||||
}
|
||||
|
||||
fn gen_text_self_attention_masks(tokens: &[f32]) -> Result<X> {
|
||||
let mut vs = tokens.to_vec();
|
||||
let n = vs.len();
|
||||
vs[0] = 1.;
|
||||
vs[n - 1] = 1.;
|
||||
let mut ys = Array::zeros((n, n)).into_dyn();
|
||||
let mut i_last = -1;
|
||||
for (i, &v) in vs.iter().enumerate() {
|
||||
if v == 0. {
|
||||
if i_last == -1 {
|
||||
i_last = i as isize;
|
||||
} else {
|
||||
i_last = -1;
|
||||
}
|
||||
} else if v == 1. {
|
||||
if i_last == -1 {
|
||||
ys.slice_mut(s![i, i]).fill(1.);
|
||||
} else {
|
||||
ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1])
|
||||
.fill(1.);
|
||||
}
|
||||
i_last = -1;
|
||||
/// Maps every token of the encoded prompt to the index of the class it belongs to.
///
/// The prompt is a sequence of class phrases, each terminated by a `"."` token.
/// The returned vector has exactly one entry per input token, so it can be
/// indexed by token position:
/// * `"[CLS]"`, `"[SEP]"` and `"."` map to `None`;
/// * any other token maps to `Some(k)`, where `k` is the number of `"."` tokens
///   seen before it, i.e. the zero-based index of the phrase it is part of.
fn process_class_ids(tokens: &[String]) -> Vec<Option<usize>> {
    let mut class_idx = 0usize;
    tokens
        .iter()
        .map(|token| match token.as_str() {
            "[CLS]" | "[SEP]" => None,
            "." => {
                // A period closes the current phrase; later tokens belong to the next class.
                class_idx += 1;
                None
            }
            _ => Some(class_idx),
        })
        .collect()
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user