mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Fix errors related to GroundingDino class_names (#71)
This commit is contained in:
@ -20,18 +20,18 @@ struct Args {
|
|||||||
#[argh(
|
#[argh(
|
||||||
option,
|
option,
|
||||||
default = "vec![
|
default = "vec![
|
||||||
String::from(\"person\"),
|
String::from(\"person\"),
|
||||||
String::from(\"hand\"),
|
String::from(\"a hand\"),
|
||||||
String::from(\"shoes\"),
|
String::from(\"a shoe\"),
|
||||||
String::from(\"bus\"),
|
String::from(\"bus\"),
|
||||||
String::from(\"dog\"),
|
String::from(\"dog\"),
|
||||||
String::from(\"cat\"),
|
String::from(\"cat\"),
|
||||||
String::from(\"sign\"),
|
String::from(\"sign\"),
|
||||||
String::from(\"tie\"),
|
String::from(\"tie\"),
|
||||||
String::from(\"monitor\"),
|
String::from(\"monitor\"),
|
||||||
String::from(\"glasses\"),
|
String::from(\"glasses\"),
|
||||||
String::from(\"tree\"),
|
String::from(\"tree\"),
|
||||||
String::from(\"head\"),
|
String::from(\"head\"),
|
||||||
]"
|
]"
|
||||||
)]
|
)]
|
||||||
labels: Vec<String>,
|
labels: Vec<String>,
|
||||||
@ -49,6 +49,8 @@ fn main() -> Result<()> {
|
|||||||
.with_model_dtype(args.dtype.as_str().try_into()?)
|
.with_model_dtype(args.dtype.as_str().try_into()?)
|
||||||
.with_model_device(args.device.as_str().try_into()?)
|
.with_model_device(args.device.as_str().try_into()?)
|
||||||
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
|
.with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
|
||||||
|
.with_class_confs(&[0.25])
|
||||||
|
.with_text_confs(&[0.25])
|
||||||
.commit()?;
|
.commit()?;
|
||||||
|
|
||||||
let mut model = GroundingDINO::new(options)?;
|
let mut model = GroundingDINO::new(options)?;
|
||||||
|
@ -12,8 +12,8 @@ impl crate::Options {
|
|||||||
.with_image_mean(&[0.485, 0.456, 0.406])
|
.with_image_mean(&[0.485, 0.456, 0.406])
|
||||||
.with_image_std(&[0.229, 0.224, 0.225])
|
.with_image_std(&[0.229, 0.224, 0.225])
|
||||||
.with_normalize(true)
|
.with_normalize(true)
|
||||||
.with_class_confs(&[0.4])
|
.with_class_confs(&[0.25])
|
||||||
.with_text_confs(&[0.3])
|
.with_text_confs(&[0.25])
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn grounding_dino_tiny() -> Self {
|
pub fn grounding_dino_tiny() -> Self {
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
use aksr::Builder;
|
use aksr::Builder;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use image::DynamicImage;
|
use image::DynamicImage;
|
||||||
use ndarray::{s, Array, Axis};
|
use ndarray::{s, Array2, Axis};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y};
|
use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y};
|
||||||
|
|
||||||
@ -15,6 +16,7 @@ pub struct GroundingDINO {
|
|||||||
confs_visual: DynConf,
|
confs_visual: DynConf,
|
||||||
confs_textual: DynConf,
|
confs_textual: DynConf,
|
||||||
class_names: Vec<String>,
|
class_names: Vec<String>,
|
||||||
|
class_ids_map: Vec<Option<usize>>,
|
||||||
tokens: Vec<String>,
|
tokens: Vec<String>,
|
||||||
token_ids: Vec<f32>,
|
token_ids: Vec<f32>,
|
||||||
ts: Ts,
|
ts: Ts,
|
||||||
@ -26,7 +28,6 @@ impl GroundingDINO {
|
|||||||
pub fn new(options: Options) -> Result<Self> {
|
pub fn new(options: Options) -> Result<Self> {
|
||||||
let engine = options.to_engine()?;
|
let engine = options.to_engine()?;
|
||||||
let spec = engine.spec().to_string();
|
let spec = engine.spec().to_string();
|
||||||
|
|
||||||
let (batch, height, width, ts) = (
|
let (batch, height, width, ts) = (
|
||||||
engine.batch().opt(),
|
engine.batch().opt(),
|
||||||
engine.try_height().unwrap_or(&800.into()).opt(),
|
engine.try_height().unwrap_or(&800.into()).opt(),
|
||||||
@ -37,20 +38,27 @@ impl GroundingDINO {
|
|||||||
.to_processor()?
|
.to_processor()?
|
||||||
.with_image_width(width as _)
|
.with_image_width(width as _)
|
||||||
.with_image_height(height as _);
|
.with_image_height(height as _);
|
||||||
let confs_visual = DynConf::new(options.class_confs(), 1);
|
let class_names = options
|
||||||
let confs_textual = DynConf::new(options.text_confs(), 1);
|
.text_names
|
||||||
|
.as_ref()
|
||||||
let class_names = Self::parse_texts(
|
.and_then(|v| {
|
||||||
&options
|
let v: Vec<_> = v
|
||||||
.text_names
|
.iter()
|
||||||
.expect("No class names specified!")
|
.map(|s| s.trim().to_ascii_lowercase())
|
||||||
.iter()
|
.filter(|s| !s.is_empty())
|
||||||
.map(|x| x.as_str())
|
.collect();
|
||||||
.collect::<Vec<_>>(),
|
(!v.is_empty()).then_some(v)
|
||||||
);
|
})
|
||||||
let token_ids = processor.encode_text_ids(&class_names, true)?;
|
.ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the options. Ensure the 'text_names' field is non-empty and contains valid class names."))?;
|
||||||
let tokens = processor.encode_text_tokens(&class_names, true)?;
|
let text_prompt = class_names.iter().fold(String::new(), |mut acc, text| {
|
||||||
let class_names = tokens.clone();
|
write!(&mut acc, "{}.", text).unwrap();
|
||||||
|
acc
|
||||||
|
});
|
||||||
|
let token_ids = processor.encode_text_ids(&text_prompt, true)?;
|
||||||
|
let tokens = processor.encode_text_tokens(&text_prompt, true)?;
|
||||||
|
let class_ids_map = Self::process_class_ids(&tokens);
|
||||||
|
let confs_visual = DynConf::new(options.class_confs(), class_names.len());
|
||||||
|
let confs_textual = DynConf::new(options.text_confs(), class_names.len());
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
engine,
|
engine,
|
||||||
@ -65,6 +73,7 @@ impl GroundingDINO {
|
|||||||
ts,
|
ts,
|
||||||
processor,
|
processor,
|
||||||
spec,
|
spec,
|
||||||
|
class_ids_map,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,30 +82,17 @@ impl GroundingDINO {
|
|||||||
let image_embeddings = self.processor.process_images(xs)?;
|
let image_embeddings = self.processor.process_images(xs)?;
|
||||||
|
|
||||||
// encode texts
|
// encode texts
|
||||||
let tokens_f32 = self
|
|
||||||
.tokens
|
|
||||||
.iter()
|
|
||||||
.map(|x| if x == "." { 1. } else { 0. })
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
// input_ids
|
|
||||||
let input_ids = X::from(self.token_ids.clone())
|
let input_ids = X::from(self.token_ids.clone())
|
||||||
.insert_axis(0)?
|
.insert_axis(0)?
|
||||||
.repeat(0, self.batch)?;
|
.repeat(0, self.batch)?;
|
||||||
|
let token_type_ids = X::zeros(&[self.batch, self.tokens.len()]);
|
||||||
// token_type_ids
|
let attention_mask = X::ones(&[self.batch, self.tokens.len()]);
|
||||||
let token_type_ids = X::zeros(&[self.batch, tokens_f32.len()]);
|
let (text_self_attention_masks, position_ids) =
|
||||||
|
Self::gen_text_attn_masks_and_pos_ids(&self.token_ids)?;
|
||||||
// attention_mask
|
let text_self_attention_masks = text_self_attention_masks
|
||||||
let attention_mask = X::ones(&[self.batch, tokens_f32.len()]);
|
|
||||||
|
|
||||||
// text_self_attention_masks
|
|
||||||
let text_self_attention_masks = Self::gen_text_self_attention_masks(&tokens_f32)?
|
|
||||||
.insert_axis(0)?
|
.insert_axis(0)?
|
||||||
.repeat(0, self.batch)?;
|
.repeat(0, self.batch)?;
|
||||||
|
let position_ids = position_ids.insert_axis(0)?.repeat(0, self.batch)?;
|
||||||
// position_ids
|
|
||||||
let position_ids = X::from(tokens_f32).insert_axis(0)?.repeat(0, self.batch)?;
|
|
||||||
|
|
||||||
// inputs
|
// inputs
|
||||||
let xs = Xs::from(vec![
|
let xs = Xs::from(vec![
|
||||||
@ -135,7 +131,6 @@ impl GroundingDINO {
|
|||||||
.filter_map(|(idx, logits)| {
|
.filter_map(|(idx, logits)| {
|
||||||
let (image_height, image_width) = self.processor.image0s_size[idx];
|
let (image_height, image_width) = self.processor.image0s_size[idx];
|
||||||
let ratio = self.processor.scale_factors_hw[idx][0];
|
let ratio = self.processor.scale_factors_hw[idx][0];
|
||||||
|
|
||||||
let y_bboxes: Vec<Bbox> = logits
|
let y_bboxes: Vec<Bbox> = logits
|
||||||
.axis_iter(Axis(0))
|
.axis_iter(Axis(0))
|
||||||
.into_par_iter()
|
.into_par_iter()
|
||||||
@ -161,13 +156,15 @@ impl GroundingDINO {
|
|||||||
let x = x.max(0.0).min(image_width as _);
|
let x = x.max(0.0).min(image_width as _);
|
||||||
let y = y.max(0.0).min(image_height as _);
|
let y = y.max(0.0).min(image_height as _);
|
||||||
|
|
||||||
Some(
|
self.class_ids_map[class_id].map(|c| {
|
||||||
Bbox::default()
|
let mut bbox =
|
||||||
.with_xywh(x, y, w, h)
|
Bbox::default().with_xywh(x, y, w, h).with_confidence(conf);
|
||||||
.with_id(class_id as _)
|
|
||||||
.with_name(&self.class_names[class_id])
|
if conf > self.confs_textual[c] {
|
||||||
.with_confidence(conf),
|
bbox = bbox.with_name(&self.class_names[c]).with_id(c as _);
|
||||||
)
|
}
|
||||||
|
bbox
|
||||||
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
@ -182,42 +179,69 @@ impl GroundingDINO {
|
|||||||
Ok(ys.into())
|
Ok(ys.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_texts(texts: &[&str]) -> String {
|
fn gen_text_attn_masks_and_pos_ids(input_ids: &[f32]) -> Result<(X, X)> {
|
||||||
let mut y = String::new();
|
let n = input_ids.len();
|
||||||
for text in texts.iter() {
|
let mut vs: Vec<f32> = input_ids
|
||||||
if !text.is_empty() {
|
.iter()
|
||||||
y.push_str(&format!("{} . ", text));
|
.map(|&x| {
|
||||||
|
if (x - 101.0).abs() < f32::EPSILON
|
||||||
|
|| (x - 1012.0).abs() < f32::EPSILON
|
||||||
|
|| (x - 102.0).abs() < f32::EPSILON
|
||||||
|
{
|
||||||
|
1.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
vs[0] = 1.0;
|
||||||
|
vs[n - 1] = 1.0;
|
||||||
|
|
||||||
|
let special_idxs: Vec<usize> = vs
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(i, &v)| {
|
||||||
|
if (v - 1.0).abs() < f32::EPSILON {
|
||||||
|
Some(i)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let mut attn = Array2::<f32>::eye(n);
|
||||||
|
let mut pos_ids = vec![0f32; n];
|
||||||
|
let mut prev = special_idxs[0];
|
||||||
|
for &idx in special_idxs.iter() {
|
||||||
|
if idx == 0 || idx == n - 1 {
|
||||||
|
} else {
|
||||||
|
for r in (prev + 1)..=idx {
|
||||||
|
for c in (prev + 1)..=idx {
|
||||||
|
attn[[r, c]] = 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (offset, pos_id) in pos_ids[prev + 1..=idx].iter_mut().enumerate() {
|
||||||
|
*pos_id = offset as f32;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
prev = idx;
|
||||||
}
|
}
|
||||||
y
|
|
||||||
|
Ok((X::from(attn.into_dyn()), X::from(pos_ids)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gen_text_self_attention_masks(tokens: &[f32]) -> Result<X> {
|
fn process_class_ids(tokens: &[String]) -> Vec<Option<usize>> {
|
||||||
let mut vs = tokens.to_vec();
|
let mut result = Vec::with_capacity(tokens.len());
|
||||||
let n = vs.len();
|
let mut idx = 0;
|
||||||
vs[0] = 1.;
|
for token in tokens {
|
||||||
vs[n - 1] = 1.;
|
if token == "[CLS]" || token == "[SEP]" {
|
||||||
let mut ys = Array::zeros((n, n)).into_dyn();
|
result.push(None);
|
||||||
let mut i_last = -1;
|
} else if token == "." {
|
||||||
for (i, &v) in vs.iter().enumerate() {
|
result.push(None);
|
||||||
if v == 0. {
|
idx += 1;
|
||||||
if i_last == -1 {
|
|
||||||
i_last = i as isize;
|
|
||||||
} else {
|
|
||||||
i_last = -1;
|
|
||||||
}
|
|
||||||
} else if v == 1. {
|
|
||||||
if i_last == -1 {
|
|
||||||
ys.slice_mut(s![i, i]).fill(1.);
|
|
||||||
} else {
|
|
||||||
ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1])
|
|
||||||
.fill(1.);
|
|
||||||
}
|
|
||||||
i_last = -1;
|
|
||||||
} else {
|
} else {
|
||||||
continue;
|
result.push(Some(idx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(X::from(ys))
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user