diff --git a/Cargo.toml b/Cargo.toml
index 892f3f2..3c5f59e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "usls"
-version = "0.0.9"
+version = "0.0.10"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
diff --git a/README.md b/README.md
index 2d0da85..9b230ac 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
-A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vison** and **Vision-Language** models including [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [RTDETR](https://arxiv.org/abs/2304.08069), [SAM](https://github.com/facebookresearch/segment-anything), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [DINOv2](https://github.com/facebookresearch/dinov2), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything) and others.
+A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vision** and **Vision-Language** models including [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [RTDETR](https://arxiv.org/abs/2304.08069), [SAM](https://github.com/facebookresearch/segment-anything), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [DINOv2](https://github.com/facebookresearch/dinov2), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) and others.
| Segment Anything |
@@ -55,6 +55,7 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp
| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ |
| [Depth-Anything (v1, v2)](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
+| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
## Installation
diff --git a/examples/blip/main.rs b/examples/blip/main.rs
index 51135f9..f1a8486 100644
--- a/examples/blip/main.rs
+++ b/examples/blip/main.rs
@@ -10,7 +10,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// textual
let options_textual = Options::default()
.with_model("blip-textual-base.onnx")?
- .with_tokenizer("tokenizer-blip.json")?
+ // .with_tokenizer("tokenizer-blip.json")?
.with_i00((1, 1, 4).into()) // input_id: batch
.with_i01((1, 1, 4).into()) // input_id: seq_len
.with_i10((1, 1, 4).into()) // attention_mask: batch
@@ -23,9 +23,10 @@ fn main() -> Result<(), Box> {
let mut model = Blip::new(options_visual, options_textual)?;
// image caption (this demo uses batch_size=1)
- let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
- let _y = model.caption(&x, None, true)?; // unconditional
- let y = model.caption(&x, Some("three man"), true)?; // conditional
+ let xs = vec![DataLoader::try_read("./assets/bus.jpg")?];
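+ // encode the images once and reuse the embeddings for both the unconditional and conditional captions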
+ let image_embeddings = model.encode_images(&xs)?;
+ let _y = model.caption(&image_embeddings, None, true)?; // unconditional
+ let y = model.caption(&image_embeddings, Some("three man"), true)?; // conditional
println!("{:?}", y[0].texts());
Ok(())
diff --git a/examples/clip/main.rs b/examples/clip/main.rs
index e2bfc1f..4873b2f 100644
--- a/examples/clip/main.rs
+++ b/examples/clip/main.rs
@@ -10,7 +10,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// textual
let options_textual = Options::default()
.with_model("clip-b32-textual-dyn.onnx")?
- .with_tokenizer("tokenizer-clip.json")?
+ // .with_tokenizer("tokenizer-clip.json")?
.with_i00((1, 1, 4).into())
.with_profile(false);
diff --git a/examples/grounding-dino/main.rs b/examples/grounding-dino/main.rs
new file mode 100644
index 0000000..0269f18
--- /dev/null
+++ b/examples/grounding-dino/main.rs
@@ -0,0 +1,40 @@
+use usls::{models::GroundingDINO, Annotator, DataLoader, Options};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
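+ // with_iXY((min, opt, max)) sets the dynamic shape range for input X, dimension Y of the ONNX model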
+ let opts = Options::default()
+ .with_i00((1, 1, 4).into())
+ .with_i02((640, 800, 1200).into())
+ .with_i03((640, 1200, 1200).into())
+ .with_i10((1, 1, 4).into())
+ .with_i11((256, 256, 512).into())
+ .with_i20((1, 1, 4).into())
+ .with_i21((256, 256, 512).into())
+ .with_i30((1, 1, 4).into())
+ .with_i31((256, 256, 512).into())
+ .with_i40((1, 1, 4).into())
+ .with_i41((256, 256, 512).into())
+ .with_i50((1, 1, 4).into())
+ .with_i51((256, 256, 512).into())
+ .with_i52((256, 256, 512).into())
+ .with_model("groundingdino-swint-ogc-dyn-u8.onnx")? // TODO: current onnx model does not support bs > 1
+ // .with_model("groundingdino-swint-ogc-dyn-f32.onnx")?
+ .with_confs(&[0.2])
+ .with_profile(false);
+ let mut model = GroundingDINO::new(opts)?;
+
+ // Load images and set class names
+ let x = [DataLoader::try_read("./assets/bus.jpg")?];
+ let texts = [
+ "person", "hand", "shoes", "bus", "dog", "cat", "sign", "tie", "monitor", "window",
+ "glasses", "tree", "head",
+ ];
+
+ // Run and annotate
+ let y = model.run(&x, &texts)?;
+ let annotator = Annotator::default()
+ .with_bboxes_thickness(4)
+ .with_saveout("GroundingDINO");
+ annotator.annotate(&x, &y);
+
+ Ok(())
+}
diff --git a/examples/sam/main.rs b/examples/sam/main.rs
index f884fae..41b04c9 100644
--- a/examples/sam/main.rs
+++ b/examples/sam/main.rs
@@ -99,7 +99,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut model = SAM::new(options_encoder, options_decoder)?;
// Load image
- let xs = vec![DataLoader::try_read("./assets/truck.jpg")?];
+ let xs = [
+ DataLoader::try_read("./assets/truck.jpg")?,
+ // DataLoader::try_read("./assets/dog.jpg")?,
+ ];
// Build annotator
let annotator = Annotator::default().with_saveout(saveout);
diff --git a/src/core/ops.rs b/src/core/ops.rs
index bc29e3e..6b5b94f 100644
--- a/src/core/ops.rs
+++ b/src/core/ops.rs
@@ -7,7 +7,7 @@ use fast_image_resize::{
FilterType, ResizeAlg, ResizeOptions, Resizer,
};
use image::{DynamicImage, GenericImageView};
-use ndarray::{s, Array, Axis, IxDyn};
+use ndarray::{s, Array, Axis, IntoDimension, IxDyn};
use rayon::prelude::*;
pub enum Ops<'a> {
@@ -20,6 +20,10 @@ pub enum Ops<'a> {
Nhwc2nchw,
Nchw2nhwc,
Norm,
+ Sigmoid,
+ Broadcast,
+ ToShape,
+ Repeat,
}
impl Ops<'_> {
@@ -34,6 +38,41 @@ impl Ops<'_> {
Ok((x - min) / (max - min))
}
+ pub fn sigmoid(x: Array<f32, IxDyn>) -> Array<f32, IxDyn> {
+ x.mapv(|x| 1. / ((-x).exp() + 1.))
+ }
+
+ pub fn broadcast<D: IntoDimension + std::fmt::Debug + Copy>(
+ x: Array<f32, IxDyn>,
+ dim: D,
+ ) -> Result<Array<f32, IxDyn>> {
+ match x.broadcast(dim) {
+ Some(x) => Ok(x.to_owned().into_dyn()),
+ None => anyhow::bail!(
+ "Failed to broadcast. Shape: {:?}, dim: {:?}",
+ x.shape(),
+ dim
+ ),
+ }
+ }
+
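+ // broadcast axis `d` to length `n`; only valid when axis `d` currently has length 1 or already equals `n`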
+ pub fn repeat(x: Array<f32, IxDyn>, d: usize, n: usize) -> Result<Array<f32, IxDyn>> {
+ if d >= x.ndim() {
+ anyhow::bail!("Index {d} is out of bounds with size {}.", x.ndim());
+ } else {
+ let mut dim = x.shape().to_vec();
+ dim[d] = n;
+ Self::broadcast(x, dim.as_slice())
+ }
+ }
+
+ pub fn to_shape<D: ndarray::ShapeArg>(
+ x: Array<f32, IxDyn>,
+ dim: D,
+ ) -> Result<Array<f32, IxDyn>> {
+ Ok(x.to_shape(dim).map(|x| x.to_owned().into_dyn())?)
+ }
+
pub fn standardize(
x: Array<f32, IxDyn>,
mean: &[f32],
diff --git a/src/core/options.rs b/src/core/options.rs
index dc4b00a..0b11113 100644
--- a/src/core/options.rs
+++ b/src/core/options.rs
@@ -73,10 +73,13 @@ pub struct Options {
pub nk: Option<usize>,
pub nm: Option<usize>,
pub confs: Vec<f32>,
+ pub confs2: Vec<f32>,
+ pub confs3: Vec<f32>,
pub kconfs: Vec<f32>,
pub iou: Option<f32>,
pub tokenizer: Option<String>,
pub vocab: Option<String>,
+ pub context_length: Option<usize>,
pub names: Option<Vec<String>>, // names
pub names2: Option<Vec<String>>, // names2
pub names3: Option<Vec<String>>, // names3
@@ -152,11 +155,14 @@ impl Default for Options {
nc: None,
nk: None,
nm: None,
- confs: vec![0.4f32],
+ confs: vec![0.3f32],
+ confs2: vec![0.3f32],
+ confs3: vec![0.3f32],
kconfs: vec![0.5f32],
iou: None,
tokenizer: None,
vocab: None,
+ context_length: None,
names: None,
names2: None,
names3: None,
@@ -255,12 +261,17 @@ impl Options {
}
pub fn with_vocab(mut self, vocab: &str) -> Result<Self> {
- self.vocab = Some(auto_load(vocab, Some("models"))?);
+ self.vocab = Some(auto_load(vocab, Some("tokenizers"))?);
Ok(self)
}
+ pub fn with_context_length(mut self, n: usize) -> Self {
+ self.context_length = Some(n);
+ self
+ }
+
pub fn with_tokenizer(mut self, tokenizer: &str) -> Result<Self> {
- self.tokenizer = Some(auto_load(tokenizer, Some("models"))?);
+ self.tokenizer = Some(auto_load(tokenizer, Some("tokenizers"))?);
Ok(self)
}
@@ -299,8 +310,18 @@ impl Options {
self
}
- pub fn with_confs(mut self, confs: &[f32]) -> Self {
- self.confs = confs.to_vec();
+ pub fn with_confs(mut self, x: &[f32]) -> Self {
+ self.confs = x.to_vec();
+ self
+ }
+
+ pub fn with_confs2(mut self, x: &[f32]) -> Self {
+ self.confs2 = x.to_vec();
+ self
+ }
+
+ pub fn with_confs3(mut self, x: &[f32]) -> Self {
+ self.confs3 = x.to_vec();
self
}
diff --git a/src/core/ort_engine.rs b/src/core/ort_engine.rs
index a30c208..a09aeeb 100644
--- a/src/core/ort_engine.rs
+++ b/src/core/ort_engine.rs
@@ -321,6 +321,9 @@ impl OrtEngine {
TensorElementType::Int8 => {
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
}
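+ // bool inputs: any non-zero f32 is mapped to true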
+ TensorElementType::Bool => {
+ ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
+ }
_ => todo!(),
};
xs_.push(Into::>::into(x_));
diff --git a/src/core/x.rs b/src/core/x.rs
index beb026b..433479d 100644
--- a/src/core/x.rs
+++ b/src/core/x.rs
@@ -1,6 +1,6 @@
use anyhow::Result;
use image::DynamicImage;
-use ndarray::{Array, Dim, IxDyn, IxDynImpl};
+use ndarray::{Array, Dim, IntoDimension, IxDyn, IxDynImpl};
use crate::Ops;
@@ -51,12 +51,28 @@ impl X {
Ops::InsertAxis(d) => y.insert_axis(*d)?,
Ops::Nhwc2nchw => y.nhwc2nchw()?,
Ops::Nchw2nhwc => y.nchw2nhwc()?,
+ Ops::Sigmoid => y.sigmoid()?,
_ => todo!(),
}
}
Ok(y)
}
+ pub fn sigmoid(mut self) -> Result<Self> {
+ self.0 = Ops::sigmoid(self.0);
+ Ok(self)
+ }
+
+ pub fn broadcast<D: IntoDimension + std::fmt::Debug + Copy>(mut self, dim: D) -> Result<Self> {
+ self.0 = Ops::broadcast(self.0, dim)?;
+ Ok(self)
+ }
+
+ pub fn to_shape<D: ndarray::ShapeArg>(mut self, dim: D) -> Result<Self> {
+ self.0 = Ops::to_shape(self.0, dim)?;
+ Ok(self)
+ }
+
pub fn permute(mut self, shape: &[usize]) -> Result<Self> {
self.0 = Ops::permute(self.0, shape)?;
Ok(self)
@@ -77,6 +93,11 @@ impl X {
Ok(self)
}
+ pub fn repeat(mut self, d: usize, n: usize) -> Result<Self> {
+ self.0 = Ops::repeat(self.0, d, n)?;
+ Ok(self)
+ }
+
pub fn dims(&self) -> &[usize] {
self.0.shape()
}
diff --git a/src/models/blip.rs b/src/models/blip.rs
index 15e36d3..268aa32 100644
--- a/src/models/blip.rs
+++ b/src/models/blip.rs
@@ -1,11 +1,12 @@
use anyhow::Result;
use image::DynamicImage;
-use ndarray::{s, Array, Axis, IxDyn};
+use ndarray::s;
use std::io::Write;
use tokenizers::Tokenizer;
use crate::{
- Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, Xs, X, Y,
+ auto_load, Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, Xs,
+ X, Y,
};
#[derive(Debug)]
@@ -29,7 +30,19 @@ impl Blip {
visual.height().to_owned(),
visual.width().to_owned(),
);
- let tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap();
+
+ let tokenizer = match options_textual.tokenizer {
+ Some(x) => x,
+ None => match auto_load("tokenizer-blip.json", Some("tokenizers")) {
+ Err(err) => anyhow::bail!("No tokenizer's file found: {:?}", err),
+ Ok(x) => x,
+ },
+ };
+ let tokenizer = match Tokenizer::from_file(tokenizer) {
+ Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err),
+ Ok(x) => x,
+ };
+
let tokenizer = TokenizerStream::new(tokenizer);
visual.dry_run()?;
textual.dry_run()?;
@@ -64,17 +77,14 @@ impl Blip {
Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
}
- pub fn caption(
- &mut self,
- x: &[DynamicImage],
- prompt: Option<&str>,
- show: bool,
- ) -> Result<Vec<Y>> {
+ pub fn caption(&mut self, xs: &Y, prompt: Option<&str>, show: bool) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
- let image_embeds = self.encode_images(x)?;
- let image_embeds = image_embeds.embedding().unwrap();
- let image_embeds_attn_mask: Array<f32, IxDyn> =
- Array::ones((1, image_embeds.data().shape()[1])).into_dyn();
+ let image_embeds = match xs.embedding() {
+ Some(x) => X::from(x.data().to_owned()),
+ None => anyhow::bail!("No image embeddings found."),
+ };
+ let image_embeds_attn_mask = X::ones(&[self.batch_visual(), image_embeds.dims()[1]]);
+
let mut y_text = String::new();
// conditional
@@ -86,13 +96,11 @@ impl Blip {
vec![0.0f32]
}
Some(prompt) => {
- let encodings = self.tokenizer.tokenizer().encode(prompt, false);
- let ids: Vec<f32> = encodings
- .unwrap()
- .get_ids()
- .iter()
- .map(|x| *x as f32)
- .collect();
+ let encodings = match self.tokenizer.tokenizer().encode(prompt, false) {
+ Err(err) => anyhow::bail!("{}", err),
+ Ok(x) => x,
+ };
+ let ids: Vec<f32> = encodings.get_ids().iter().map(|x| *x as f32).collect();
if show {
print!("[Conditional]: {} ", prompt);
}
@@ -103,18 +111,16 @@ impl Blip {
let mut logits_sampler = LogitsSampler::new();
loop {
- let input_ids_nd: Array<f32, IxDyn> = Array::from_vec(input_ids.to_owned()).into_dyn();
- let input_ids_nd = input_ids_nd.insert_axis(Axis(0));
- let input_ids_nd = X::from(input_ids_nd);
- let input_ids_attn_mask: Array<f32, IxDyn> =
- Array::ones(input_ids_nd.shape()).into_dyn();
- let input_ids_attn_mask = X::from(input_ids_attn_mask);
+ let input_ids_nd = X::from(input_ids.to_owned())
+ .insert_axis(0)?
+ .repeat(0, self.batch_textual())?;
+ let input_ids_attn_mask = X::ones(input_ids_nd.dims());
let y = self.textual.run(Xs::from(vec![
input_ids_nd,
input_ids_attn_mask,
- X::from(image_embeds.data().to_owned()),
- X::from(image_embeds_attn_mask.to_owned()),
+ image_embeds.clone(),
+ image_embeds_attn_mask.clone(),
]))?; // N, length, vocab_size
let y = y[0].slice(s!(0, -1.., ..));
let logits = y.slice(s!(0, ..)).to_vec();
diff --git a/src/models/clip.rs b/src/models/clip.rs
index 14a1392..c1e848c 100644
--- a/src/models/clip.rs
+++ b/src/models/clip.rs
@@ -3,7 +3,7 @@ use image::DynamicImage;
use ndarray::Array2;
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};
-use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
+use crate::{auto_load, Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
#[derive(Debug)]
pub struct Clip {
@@ -28,7 +28,19 @@ impl Clip {
visual.inputs_minoptmax()[0][2].to_owned(),
visual.inputs_minoptmax()[0][3].to_owned(),
);
- let mut tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap();
+
+ let tokenizer = match options_textual.tokenizer {
+ Some(x) => x,
+ None => match auto_load("tokenizer-clip.json", Some("tokenizers")) {
+ Err(err) => anyhow::bail!("No tokenizer's file found: {:?}", err),
+ Ok(x) => x,
+ },
+ };
+ let mut tokenizer = match Tokenizer::from_file(tokenizer) {
+ Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err),
+ Ok(x) => x,
+ };
+
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::Fixed(context_length),
direction: PaddingDirection::Right,
@@ -74,10 +86,10 @@ impl Clip {
}
pub fn encode_texts(&mut self, texts: &[String]) -> Result<Y> {
- let encodings = self
- .tokenizer
- .encode_batch(texts.to_owned(), false)
- .unwrap();
+ let encodings = match self.tokenizer.encode_batch(texts.to_owned(), false) {
+ Err(err) => anyhow::bail!("{:?}", err),
+ Ok(x) => x,
+ };
let xs: Vec<f32> = encodings
.iter()
.flat_map(|i| i.get_ids().iter().map(|&b| b as f32))
diff --git a/src/models/grounding_dino.rs b/src/models/grounding_dino.rs
new file mode 100644
index 0000000..58cfc09
--- /dev/null
+++ b/src/models/grounding_dino.rs
@@ -0,0 +1,249 @@
+use crate::{auto_load, Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
+use anyhow::Result;
+use image::DynamicImage;
+use ndarray::{s, Array, Axis};
+use rayon::prelude::*;
+use tokenizers::{Encoding, Tokenizer};
+
+#[derive(Debug)]
+pub struct GroundingDINO {
+ pub engine: OrtEngine,
+ height: MinOptMax,
+ width: MinOptMax,
+ batch: MinOptMax,
+ tokenizer: Tokenizer,
+ pub context_length: usize,
+ confs_visual: DynConf,
+ confs_textual: DynConf,
+}
+
+impl GroundingDINO {
+ pub fn new(options: Options) -> Result<Self> {
+ let mut engine = OrtEngine::new(&options)?;
+ let (batch, height, width) = (
+ engine.inputs_minoptmax()[0][0].to_owned(),
+ engine.inputs_minoptmax()[0][2].to_owned(),
+ engine.inputs_minoptmax()[0][3].to_owned(),
+ );
+ let context_length = options.context_length.unwrap_or(256);
+ // let special_tokens = ["[CLS]", "[SEP]", ".", "?"];
+ let tokenizer = match options.tokenizer {
+ Some(x) => x,
+ None => match auto_load("tokenizer-groundingdino.json", Some("tokenizers")) {
+ Err(err) => anyhow::bail!("No tokenizer's file found: {:?}", err),
+ Ok(x) => x,
+ },
+ };
+ let tokenizer = match Tokenizer::from_file(tokenizer) {
+ Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err),
+ Ok(x) => x,
+ };
+ let confs_visual = DynConf::new(&options.confs, 1);
+ let confs_textual = DynConf::new(&options.confs, 1);
+
+ engine.dry_run()?;
+
+ Ok(Self {
+ engine,
+ batch,
+ height,
+ width,
+ tokenizer,
+ context_length,
+ confs_visual,
+ confs_textual,
+ })
+ }
+
+ pub fn run(&mut self, xs: &[DynamicImage], texts: &[&str]) -> Result<Vec<Y>> {
+ // image embeddings
+ let image_embeddings = X::apply(&[
+ Ops::Letterbox(
+ xs,
+ self.height() as u32,
+ self.width() as u32,
+ "CatmullRom",
+ 114,
+ "auto",
+ false,
+ ),
+ Ops::Normalize(0., 255.),
+ Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3),
+ Ops::Nhwc2nchw,
+ ])?;
+
+ // encoding
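+ // (parse_texts() builds a single caption where phrases are joined by " . ", e.g. "person . hand . bus . ")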
+ let text = Self::parse_texts(texts);
+ let encoding = match self.tokenizer.encode(text, true) {
+ Err(err) => anyhow::bail!("{}", err),
+ Ok(x) => x,
+ };
+ let tokens = encoding.get_tokens();
+
+ // input_ids
+ let input_ids = X::from(
+ encoding
+ .get_ids()
+ .iter()
+ .map(|&x| x as f32)
+ .collect::<Vec<f32>>(),
+ )
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?;
+
+ // token_type_ids
+ let token_type_ids = X::zeros(&[self.batch() as usize, tokens.len()]);
+
+ // attention_mask
+ let attention_mask = X::ones(&[self.batch() as usize, tokens.len()]);
+
+ // position_ids
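+ // (in this implementation, "." separator tokens map to 1 and all other tokens to 0)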
+ let position_ids = X::from(
+ encoding
+ .get_tokens()
+ .iter()
+ .map(|x| if x == "." { 1. } else { 0. })
+ .collect::<Vec<f32>>(),
+ )
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?;
+
+ // text_self_attention_masks
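+ // (block-diagonal mask: each token only attends to tokens within its own "."-delimited phrase)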
+ let text_self_attention_masks = Self::gen_text_self_attention_masks(&encoding)?
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?;
+
+ // run
+ let ys = self.engine.run(Xs::from(vec![
+ image_embeddings,
+ input_ids,
+ attention_mask,
+ position_ids,
+ token_type_ids,
+ text_self_attention_masks,
+ ]))?;
+
+ // post-process
+ self.postprocess(ys, xs, tokens)
+ }
+
+ fn postprocess(&self, xs: Xs, xs0: &[DynamicImage], tokens: &[String]) -> Result<Vec<Y>> {
+ let ys: Vec = xs["logits"]
+ .axis_iter(Axis(0))
+ .into_par_iter()
+ .enumerate()
+ .filter_map(|(idx, logits)| {
+ let image_width = xs0[idx].width() as f32;
+ let image_height = xs0[idx].height() as f32;
+ let ratio =
+ (self.width() as f32 / image_width).min(self.height() as f32 / image_height);
+
+ let y_bboxes: Vec<Bbox> = logits
+ .axis_iter(Axis(0))
+ .into_par_iter()
+ .enumerate()
+ .filter_map(|(i, clss)| {
+ let (class_id, &conf) = clss
+ .mapv(|x| 1. / ((-x).exp() + 1.))
+ .iter()
+ .enumerate()
+ .max_by(|a, b| a.1.total_cmp(b.1))?;
+
+ if conf < self.conf_visual() {
+ return None;
+ }
+
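+ // boxes are normalized cxcywh w.r.t. the letterboxed model input; undo the letterbox
+ // scaling and convert to top-left xywh in original-image pixel coordinates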
+ let bbox = xs["boxes"].slice(s![idx, i, ..]).mapv(|x| x / ratio);
+ let cx = bbox[0] * self.width() as f32;
+ let cy = bbox[1] * self.height() as f32;
+ let w = bbox[2] * self.width() as f32;
+ let h = bbox[3] * self.height() as f32;
+ let x = cx - w / 2.;
+ let y = cy - h / 2.;
+ let x = x.max(0.0).min(image_width);
+ let y = y.max(0.0).min(image_height);
+
+ Some(
+ Bbox::default()
+ .with_xywh(x, y, w, h)
+ .with_id(class_id as _)
+ .with_name(&tokens[class_id])
+ .with_confidence(conf),
+ )
+ })
+ .collect();
+
+ if !y_bboxes.is_empty() {
+ Some(Y::default().with_bboxes(&y_bboxes))
+ } else {
+ None
+ }
+ })
+ .collect();
+ Ok(ys)
+ }
+
+ fn parse_texts(texts: &[&str]) -> String {
+ let mut y = String::new();
+ for text in texts.iter() {
+ if !text.is_empty() {
+ y.push_str(&format!("{} . ", text));
+ }
+ }
+ y
+ }
+
+ fn gen_text_self_attention_masks(encoding: &Encoding) -> Result<X> {
+ let mut vs = encoding
+ .get_tokens()
+ .iter()
+ .map(|x| if x == "." { 1. } else { 0. })
+ .collect::<Vec<f32>>();
+
+ let n = vs.len();
+ vs[0] = 1.;
+ vs[n - 1] = 1.;
+ let mut ys = Array::zeros((n, n)).into_dyn();
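+ // fill a square block of ones over each separator-delimited span of tokens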
+ let mut i_last = -1;
+ for (i, &v) in vs.iter().enumerate() {
+ if v == 0. {
+ if i_last == -1 {
+ i_last = i as isize;
+ } else {
+ i_last = -1;
+ }
+ } else if v == 1. {
+ if i_last == -1 {
+ ys.slice_mut(s![i, i]).fill(1.);
+ } else {
+ ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1])
+ .fill(1.);
+ }
+ i_last = -1;
+ } else {
+ continue;
+ }
+ }
+ Ok(X::from(ys))
+ }
+
+ pub fn conf_visual(&self) -> f32 {
+ self.confs_visual[0]
+ }
+
+ pub fn conf_textual(&self) -> f32 {
+ self.confs_textual[0]
+ }
+
+ pub fn batch(&self) -> isize {
+ self.batch.opt
+ }
+
+ pub fn width(&self) -> isize {
+ self.width.opt
+ }
+
+ pub fn height(&self) -> isize {
+ self.height.opt
+ }
+}
diff --git a/src/models/mod.rs b/src/models/mod.rs
index da81975..237fe8f 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -5,6 +5,7 @@ mod clip;
mod db;
mod depth_anything;
mod dinov2;
+mod grounding_dino;
mod modnet;
mod rtmo;
mod sam;
@@ -18,6 +19,7 @@ pub use clip::Clip;
pub use db::DB;
pub use depth_anything::DepthAnything;
pub use dinov2::Dinov2;
+pub use grounding_dino::GroundingDINO;
pub use modnet::MODNet;
pub use rtmo::RTMO;
pub use sam::{SamKind, SamPrompt, SAM};
diff --git a/src/models/rtmo.rs b/src/models/rtmo.rs
index 72b2a8f..b3b5569 100644
--- a/src/models/rtmo.rs
+++ b/src/models/rtmo.rs
@@ -97,7 +97,7 @@ impl RTMO {
)
.with_confidence(confidence)
.with_id(0isize)
- .with_name(Some(String::from("Person"))),
+ .with_name("Person"),
);
// keypoints
diff --git a/src/models/sam.rs b/src/models/sam.rs
index 3eb7e7b..8d56e96 100644
--- a/src/models/sam.rs
+++ b/src/models/sam.rs
@@ -116,7 +116,7 @@ impl SAM {
pub fn run(&mut self, xs: &[DynamicImage], prompts: &[SamPrompt]) -> Result<Vec<Y>> {
let ys = self.encode(xs)?;
- self.decode(ys, xs, prompts)
+ self.decode(&ys, xs, prompts)
}
pub fn encode(&mut self, xs: &[DynamicImage]) -> Result<Xs> {
@@ -139,7 +139,7 @@ impl SAM {
pub fn decode(
&mut self,
- xs: Xs,
+ xs: &Xs,
xs0: &[DynamicImage],
prompts: &[SamPrompt],
) -> Result<Vec<Y>> {
@@ -157,7 +157,9 @@ impl SAM {
let args = match self.kind {
SamKind::Sam | SamKind::MobileSam => {
vec![
- X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?, // image_embedding
+ X::from(image_embedding.into_dyn().into_owned())
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?, // image_embedding
prompts[idx].point_coords(ratio)?, // point_coords
prompts[idx].point_labels()?, // point_labels
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input,
@@ -167,10 +169,13 @@ impl SAM {
}
SamKind::SamHq => {
vec![
- X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?, // image_embedding
+ X::from(image_embedding.into_dyn().into_owned())
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?, // image_embedding
X::from(xs[1].slice(s![idx, .., .., ..]).into_dyn().into_owned())
.insert_axis(0)?
- .insert_axis(0)?, // intern_embedding
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?, // intern_embedding
prompts[idx].point_coords(ratio)?, // point_coords
prompts[idx].point_labels()?, // point_labels
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
@@ -180,14 +185,18 @@ impl SAM {
}
SamKind::EdgeSam => {
vec![
- X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
+ X::from(image_embedding.into_dyn().into_owned())
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?,
prompts[idx].point_coords(ratio)?,
prompts[idx].point_labels()?,
]
}
SamKind::Sam2 => {
vec![
- X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
+ X::from(image_embedding.into_dyn().into_owned())
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?,
X::from(
high_res_features_0
.unwrap()
@@ -195,7 +204,8 @@ impl SAM {
.into_dyn()
.into_owned(),
)
- .insert_axis(0)?,
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?,
X::from(
high_res_features_1
.unwrap()
@@ -203,7 +213,8 @@ impl SAM {
.into_dyn()
.into_owned(),
)
- .insert_axis(0)?,
+ .insert_axis(0)?
+ .repeat(0, self.batch() as usize)?,
prompts[idx].point_coords(ratio)?,
prompts[idx].point_labels()?,
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
diff --git a/src/models/yolo.rs b/src/models/yolo.rs
index 82f3878..5d75221 100644
--- a/src/models/yolo.rs
+++ b/src/models/yolo.rs
@@ -215,13 +215,13 @@ impl Vision for YOLO {
} else {
slice_clss.into_owned()
};
- return Some(
- y.with_probs(
- &Prob::default()
- .with_probs(&x.into_raw_vec())
- .with_names(self.names.clone()),
- ),
- );
+ let mut probs = Prob::default().with_probs(&x.into_raw_vec());
+ if let Some(names) = &self.names {
+ probs =
+ probs.with_names(&names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
+ }
+
+ return Some(y.with_probs(&probs));
}
let image_width = xs0[idx].width() as f32;
@@ -312,41 +312,34 @@ impl Vision for YOLO {
(h, w, radians + std::f32::consts::PI / 2.)
};
let radians = radians % std::f32::consts::PI;
- (
- None,
- Some(
- Mbr::from_cxcywhr(
- cx as f64,
- cy as f64,
- w as f64,
- h as f64,
- radians as f64,
- )
- .with_confidence(confidence)
- .with_id(class_id as isize)
- .with_name(
- self.names
- .as_ref()
- .map(|names| names[class_id].clone()),
- ),
- ),
+
+ let mut mbr = Mbr::from_cxcywhr(
+ cx as f64,
+ cy as f64,
+ w as f64,
+ h as f64,
+ radians as f64,
)
+ .with_confidence(confidence)
+ .with_id(class_id as isize);
+ if let Some(names) = &self.names {
+ mbr = mbr.with_name(&names[class_id]);
+ }
+
+ (None, Some(mbr))
+ }
+ None => {
+ let mut bbox = Bbox::default()
+ .with_xywh(x, y, w, h)
+ .with_confidence(confidence)
+ .with_id(class_id as isize)
+ .with_id_born(i as isize);
+ if let Some(names) = &self.names {
+ bbox = bbox.with_name(&names[class_id]);
+ }
+
+ (Some(bbox), None)
}
- None => (
- Some(
- Bbox::default()
- .with_xywh(x, y, w, h)
- .with_confidence(confidence)
- .with_id(class_id as isize)
- .with_id_born(i as isize)
- .with_name(
- self.names
- .as_ref()
- .map(|names| names[class_id].clone()),
- ),
- ),
- None,
- ),
};
Some((y_bbox, y_mbr))
@@ -390,18 +383,18 @@ impl Vision for YOLO {
if kconf < self.kconfs[i] {
Keypoint::default()
} else {
- Keypoint::default()
+ let mut kpt = Keypoint::default()
.with_id(i as isize)
.with_confidence(kconf)
- .with_name(
- self.names_kpt
- .as_ref()
- .map(|names| names[i].clone()),
- )
.with_xy(
kx.max(0.0f32).min(image_width),
ky.max(0.0f32).min(image_height),
- )
+ );
+
+ if let Some(names) = &self.names_kpt {
+ kpt = kpt.with_name(&names[i]);
+ }
+ kpt
}
})
.collect::<Vec<_>>();
@@ -468,23 +461,25 @@ impl Vision for YOLO {
contours
.into_par_iter()
.map(|x| {
- Polygon::default()
+ let mut polygon = Polygon::default()
.with_id(bbox.id())
- .with_points_imageproc(&x.points)
- .with_name(bbox.name().cloned())
+ .with_points_imageproc(&x.points);
+ if let Some(name) = bbox.name() {
+ polygon = polygon.with_name(name);
+ }
+ polygon
})
.max_by(|x, y| x.area().total_cmp(&y.area()))?
} else {
Polygon::default()
};
- Some((
- polygons,
- Mask::default()
- .with_mask(mask)
- .with_id(bbox.id())
- .with_name(bbox.name().cloned()),
- ))
+ let mut mask = Mask::default().with_mask(mask).with_id(bbox.id());
+ if let Some(name) = bbox.name() {
+ mask = mask.with_name(name);
+ }
+
+ Some((polygons, mask))
})
.collect::<(Vec<_>, Vec<_>)>();
diff --git a/src/models/yolop.rs b/src/models/yolop.rs
index 563e9d7..fbb1794 100644
--- a/src/models/yolop.rs
+++ b/src/models/yolop.rs
@@ -126,7 +126,7 @@ impl YOLOPv2 {
Polygon::default()
.with_id(0)
.with_points_imageproc(&x.points)
- .with_name(Some("Drivable area".to_string()))
+ .with_name("Drivable area")
})
.max_by(|x, y| x.area().total_cmp(&y.area()))
{
@@ -151,7 +151,7 @@ impl YOLOPv2 {
Polygon::default()
.with_id(1)
.with_points_imageproc(&x.points)
- .with_name(Some("Lane line".to_string()))
+ .with_name("Lane line")
})
.max_by(|x, y| x.area().total_cmp(&y.area()))
{
diff --git a/src/ys/bbox.rs b/src/ys/bbox.rs
index 11e4a47..a19fdfc 100644
--- a/src/ys/bbox.rs
+++ b/src/ys/bbox.rs
@@ -205,13 +205,13 @@ impl Bbox {
///
/// # Arguments
///
- /// * `x` - The optional name to be set.
+ /// * `x` - The name to be set.
///
/// # Returns
///
/// A `Bbox` instance with updated name.
- pub fn with_name(mut self, x: Option<String>) -> Self {
- self.name = x;
+ pub fn with_name(mut self, x: &str) -> Self {
+ self.name = Some(x.to_string());
self
}
diff --git a/src/ys/keypoint.rs b/src/ys/keypoint.rs
index f22f6df..54a43c9 100644
--- a/src/ys/keypoint.rs
+++ b/src/ys/keypoint.rs
@@ -190,8 +190,8 @@ impl Keypoint {
self
}
- pub fn with_name(mut self, x: Option<String>) -> Self {
- self.name = x;
+ pub fn with_name(mut self, x: &str) -> Self {
+ self.name = Some(x.to_string());
self
}
diff --git a/src/ys/mask.rs b/src/ys/mask.rs
index 5b3ccff..f3e91ea 100644
--- a/src/ys/mask.rs
+++ b/src/ys/mask.rs
@@ -41,8 +41,8 @@ impl Mask {
self
}
- pub fn with_name(mut self, x: Option<String>) -> Self {
- self.name = x;
+ pub fn with_name(mut self, x: &str) -> Self {
+ self.name = Some(x.to_string());
self
}
diff --git a/src/ys/mbr.rs b/src/ys/mbr.rs
index 552fb6e..7325508 100644
--- a/src/ys/mbr.rs
+++ b/src/ys/mbr.rs
@@ -101,8 +101,8 @@ impl Mbr {
self
}
- pub fn with_name(mut self, x: Option<String>) -> Self {
- self.name = x;
+ pub fn with_name(mut self, x: &str) -> Self {
+ self.name = Some(x.to_string());
self
}
diff --git a/src/ys/polygon.rs b/src/ys/polygon.rs
index f4df61d..62a8e1c 100644
--- a/src/ys/polygon.rs
+++ b/src/ys/polygon.rs
@@ -59,8 +59,8 @@ impl Polygon {
self
}
- pub fn with_name(mut self, x: Option<String>) -> Self {
- self.name = x;
+ pub fn with_name(mut self, x: &str) -> Self {
+ self.name = Some(x.to_string());
self
}
diff --git a/src/ys/prob.rs b/src/ys/prob.rs
index 3e5ba80..f2aba92 100644
--- a/src/ys/prob.rs
+++ b/src/ys/prob.rs
@@ -12,8 +12,9 @@ impl std::fmt::Debug for Prob {
}
impl Prob {
- pub fn with_names(mut self, x: Option<Vec<String>>) -> Self {
- self.names = x;
+ pub fn with_names(mut self, names: &[&str]) -> Self {
+ let names = names.iter().map(|x| x.to_string()).collect::<Vec<String>>();
+ self.names = Some(names);
self
}