Add GroundingDINO (#30)

This commit is contained in:
Jamjamjon
2024-08-09 19:06:30 +08:00
committed by GitHub
parent 53d14ee2fb
commit b81b5e3cf5
24 changed files with 536 additions and 131 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "usls"
version = "0.0.9"
version = "0.0.10"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"

View File

@ -5,7 +5,7 @@
A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vison** and **Vision-Language** models including [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [RTDETR](https://arxiv.org/abs/2304.08069), [SAM](https://github.com/facebookresearch/segment-anything), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [DINOv2](https://github.com/facebookresearch/dinov2), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything) and others.
A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vison** and **Vision-Language** models including [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [RTDETR](https://arxiv.org/abs/2304.08069), [SAM](https://github.com/facebookresearch/segment-anything), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [DINOv2](https://github.com/facebookresearch/dinov2), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) and others.
| Segment Anything |
@ -55,6 +55,7 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp
| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ |
| [Depth-Anything<br />(v1, v2)](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
## Installation

View File

@ -10,7 +10,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// textual
let options_textual = Options::default()
.with_model("blip-textual-base.onnx")?
.with_tokenizer("tokenizer-blip.json")?
// .with_tokenizer("tokenizer-blip.json")?
.with_i00((1, 1, 4).into()) // input_id: batch
.with_i01((1, 1, 4).into()) // input_id: seq_len
.with_i10((1, 1, 4).into()) // attention_mask: batch
@ -23,9 +23,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut model = Blip::new(options_visual, options_textual)?;
// image caption (this demo use batch_size=1)
let x = vec![DataLoader::try_read("./assets/bus.jpg")?];
let _y = model.caption(&x, None, true)?; // unconditional
let y = model.caption(&x, Some("three man"), true)?; // conditional
let xs = vec![DataLoader::try_read("./assets/bus.jpg")?];
let image_embeddings = model.encode_images(&xs)?;
let _y = model.caption(&image_embeddings, None, true)?; // unconditional
let y = model.caption(&image_embeddings, Some("three man"), true)?; // conditional
println!("{:?}", y[0].texts());
Ok(())

View File

@ -10,7 +10,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// textual
let options_textual = Options::default()
.with_model("clip-b32-textual-dyn.onnx")?
.with_tokenizer("tokenizer-clip.json")?
// .with_tokenizer("tokenizer-clip.json")?
.with_i00((1, 1, 4).into())
.with_profile(false);

View File

@ -0,0 +1,40 @@
use usls::{models::GroundingDINO, Annotator, DataLoader, Options};
fn main() -> Result<(), Box<dyn std::error::Error>> {
let opts = Options::default()
.with_i00((1, 1, 4).into())
.with_i02((640, 800, 1200).into())
.with_i03((640, 1200, 1200).into())
.with_i10((1, 1, 4).into())
.with_i11((256, 256, 512).into())
.with_i20((1, 1, 4).into())
.with_i21((256, 256, 512).into())
.with_i30((1, 1, 4).into())
.with_i31((256, 256, 512).into())
.with_i40((1, 1, 4).into())
.with_i41((256, 256, 512).into())
.with_i50((1, 1, 4).into())
.with_i51((256, 256, 512).into())
.with_i52((256, 256, 512).into())
.with_model("groundingdino-swint-ogc-dyn-u8.onnx")? // TODO: current onnx model does not support bs > 1
// .with_model("groundingdino-swint-ogc-dyn-f32.onnx")?
.with_confs(&[0.2])
.with_profile(false);
let mut model = GroundingDINO::new(opts)?;
// Load images and set class names
let x = [DataLoader::try_read("./assets/bus.jpg")?];
let texts = [
"person", "hand", "shoes", "bus", "dog", "cat", "sign", "tie", "monitor", "window",
"glasses", "tree", "head",
];
// Run and annotate
let y = model.run(&x, &texts)?;
let annotator = Annotator::default()
.with_bboxes_thickness(4)
.with_saveout("GroundingDINO");
annotator.annotate(&x, &y);
Ok(())
}

View File

@ -99,7 +99,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut model = SAM::new(options_encoder, options_decoder)?;
// Load image
let xs = vec![DataLoader::try_read("./assets/truck.jpg")?];
let xs = [
DataLoader::try_read("./assets/truck.jpg")?,
// DataLoader::try_read("./assets/dog.jpg")?,
];
// Build annotator
let annotator = Annotator::default().with_saveout(saveout);

View File

@ -7,7 +7,7 @@ use fast_image_resize::{
FilterType, ResizeAlg, ResizeOptions, Resizer,
};
use image::{DynamicImage, GenericImageView};
use ndarray::{s, Array, Axis, IxDyn};
use ndarray::{s, Array, Axis, IntoDimension, IxDyn};
use rayon::prelude::*;
pub enum Ops<'a> {
@ -20,6 +20,10 @@ pub enum Ops<'a> {
Nhwc2nchw,
Nchw2nhwc,
Norm,
Sigmoid,
Broadcast,
ToShape,
Repeat,
}
impl Ops<'_> {
@ -34,6 +38,41 @@ impl Ops<'_> {
Ok((x - min) / (max - min))
}
pub fn sigmoid(x: Array<f32, IxDyn>) -> Array<f32, IxDyn> {
x.mapv(|x| 1. / ((-x).exp() + 1.))
}
pub fn broadcast<D: IntoDimension + std::fmt::Debug + Copy>(
x: Array<f32, IxDyn>,
dim: D,
) -> Result<Array<f32, IxDyn>> {
match x.broadcast(dim) {
Some(x) => Ok(x.to_owned().into_dyn()),
None => anyhow::bail!(
"Failed to broadcast. Shape: {:?}, dim: {:?}",
x.shape(),
dim
),
}
}
pub fn repeat(x: Array<f32, IxDyn>, d: usize, n: usize) -> Result<Array<f32, IxDyn>> {
if d >= x.ndim() {
anyhow::bail!("Index {d} is out of bounds with size {}.", x.ndim());
} else {
let mut dim = x.shape().to_vec();
dim[d] = n;
Self::broadcast(x, dim.as_slice())
}
}
pub fn to_shape<D: ndarray::ShapeArg>(
x: Array<f32, IxDyn>,
dim: D,
) -> Result<Array<f32, IxDyn>> {
Ok(x.to_shape(dim).map(|x| x.to_owned().into_dyn())?)
}
pub fn standardize(
x: Array<f32, IxDyn>,
mean: &[f32],

View File

@ -73,10 +73,13 @@ pub struct Options {
pub nk: Option<usize>,
pub nm: Option<usize>,
pub confs: Vec<f32>,
pub confs2: Vec<f32>,
pub confs3: Vec<f32>,
pub kconfs: Vec<f32>,
pub iou: Option<f32>,
pub tokenizer: Option<String>,
pub vocab: Option<String>,
pub context_length: Option<usize>,
pub names: Option<Vec<String>>, // names
pub names2: Option<Vec<String>>, // names2
pub names3: Option<Vec<String>>, // names3
@ -152,11 +155,14 @@ impl Default for Options {
nc: None,
nk: None,
nm: None,
confs: vec![0.4f32],
confs: vec![0.3f32],
confs2: vec![0.3f32],
confs3: vec![0.3f32],
kconfs: vec![0.5f32],
iou: None,
tokenizer: None,
vocab: None,
context_length: None,
names: None,
names2: None,
names3: None,
@ -255,12 +261,17 @@ impl Options {
}
pub fn with_vocab(mut self, vocab: &str) -> Result<Self> {
self.vocab = Some(auto_load(vocab, Some("models"))?);
self.vocab = Some(auto_load(vocab, Some("tokenizers"))?);
Ok(self)
}
pub fn with_context_length(mut self, n: usize) -> Self {
self.context_length = Some(n);
self
}
pub fn with_tokenizer(mut self, tokenizer: &str) -> Result<Self> {
self.tokenizer = Some(auto_load(tokenizer, Some("models"))?);
self.tokenizer = Some(auto_load(tokenizer, Some("tokenizers"))?);
Ok(self)
}
@ -299,8 +310,18 @@ impl Options {
self
}
pub fn with_confs(mut self, confs: &[f32]) -> Self {
self.confs = confs.to_vec();
pub fn with_confs(mut self, x: &[f32]) -> Self {
self.confs = x.to_vec();
self
}
pub fn with_confs2(mut self, x: &[f32]) -> Self {
self.confs2 = x.to_vec();
self
}
pub fn with_confs3(mut self, x: &[f32]) -> Self {
self.confs3 = x.to_vec();
self
}

View File

@ -321,6 +321,9 @@ impl OrtEngine {
TensorElementType::Int8 => {
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
}
TensorElementType::Bool => {
ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
}
_ => todo!(),
};
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));

View File

@ -1,6 +1,6 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Dim, IxDyn, IxDynImpl};
use ndarray::{Array, Dim, IntoDimension, IxDyn, IxDynImpl};
use crate::Ops;
@ -51,12 +51,28 @@ impl X {
Ops::InsertAxis(d) => y.insert_axis(*d)?,
Ops::Nhwc2nchw => y.nhwc2nchw()?,
Ops::Nchw2nhwc => y.nchw2nhwc()?,
Ops::Sigmoid => y.sigmoid()?,
_ => todo!(),
}
}
Ok(y)
}
pub fn sigmoid(mut self) -> Result<Self> {
self.0 = Ops::sigmoid(self.0);
Ok(self)
}
pub fn broadcast<D: IntoDimension + std::fmt::Debug + Copy>(mut self, dim: D) -> Result<Self> {
self.0 = Ops::broadcast(self.0, dim)?;
Ok(self)
}
pub fn to_shape<D: ndarray::ShapeArg>(mut self, dim: D) -> Result<Self> {
self.0 = Ops::to_shape(self.0, dim)?;
Ok(self)
}
pub fn permute(mut self, shape: &[usize]) -> Result<Self> {
self.0 = Ops::permute(self.0, shape)?;
Ok(self)
@ -77,6 +93,11 @@ impl X {
Ok(self)
}
pub fn repeat(mut self, d: usize, n: usize) -> Result<Self> {
self.0 = Ops::repeat(self.0, d, n)?;
Ok(self)
}
pub fn dims(&self) -> &[usize] {
self.0.shape()
}

View File

@ -1,11 +1,12 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Array, Axis, IxDyn};
use ndarray::s;
use std::io::Write;
use tokenizers::Tokenizer;
use crate::{
Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, Xs, X, Y,
auto_load, Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, Xs,
X, Y,
};
#[derive(Debug)]
@ -29,7 +30,19 @@ impl Blip {
visual.height().to_owned(),
visual.width().to_owned(),
);
let tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap();
let tokenizer = match options_textual.tokenizer {
Some(x) => x,
None => match auto_load("tokenizer-blip.json", Some("tokenizers")) {
Err(err) => anyhow::bail!("No tokenizer's file found: {:?}", err),
Ok(x) => x,
},
};
let tokenizer = match Tokenizer::from_file(tokenizer) {
Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err),
Ok(x) => x,
};
let tokenizer = TokenizerStream::new(tokenizer);
visual.dry_run()?;
textual.dry_run()?;
@ -64,17 +77,14 @@ impl Blip {
Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
}
pub fn caption(
&mut self,
x: &[DynamicImage],
prompt: Option<&str>,
show: bool,
) -> Result<Vec<Y>> {
pub fn caption(&mut self, xs: &Y, prompt: Option<&str>, show: bool) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
let image_embeds = self.encode_images(x)?;
let image_embeds = image_embeds.embedding().unwrap();
let image_embeds_attn_mask: Array<f32, IxDyn> =
Array::ones((1, image_embeds.data().shape()[1])).into_dyn();
let image_embeds = match xs.embedding() {
Some(x) => X::from(x.data().to_owned()),
None => anyhow::bail!("No image embeddings found."),
};
let image_embeds_attn_mask = X::ones(&[self.batch_visual(), image_embeds.dims()[1]]);
let mut y_text = String::new();
// conditional
@ -86,13 +96,11 @@ impl Blip {
vec![0.0f32]
}
Some(prompt) => {
let encodings = self.tokenizer.tokenizer().encode(prompt, false);
let ids: Vec<f32> = encodings
.unwrap()
.get_ids()
.iter()
.map(|x| *x as f32)
.collect();
let encodings = match self.tokenizer.tokenizer().encode(prompt, false) {
Err(err) => anyhow::bail!("{}", err),
Ok(x) => x,
};
let ids: Vec<f32> = encodings.get_ids().iter().map(|x| *x as f32).collect();
if show {
print!("[Conditional]: {} ", prompt);
}
@ -103,18 +111,16 @@ impl Blip {
let mut logits_sampler = LogitsSampler::new();
loop {
let input_ids_nd: Array<f32, IxDyn> = Array::from_vec(input_ids.to_owned()).into_dyn();
let input_ids_nd = input_ids_nd.insert_axis(Axis(0));
let input_ids_nd = X::from(input_ids_nd);
let input_ids_attn_mask: Array<f32, IxDyn> =
Array::ones(input_ids_nd.shape()).into_dyn();
let input_ids_attn_mask = X::from(input_ids_attn_mask);
let input_ids_nd = X::from(input_ids.to_owned())
.insert_axis(0)?
.repeat(0, self.batch_textual())?;
let input_ids_attn_mask = X::ones(input_ids_nd.dims());
let y = self.textual.run(Xs::from(vec![
input_ids_nd,
input_ids_attn_mask,
X::from(image_embeds.data().to_owned()),
X::from(image_embeds_attn_mask.to_owned()),
image_embeds.clone(),
image_embeds_attn_mask.clone(),
]))?; // N, length, vocab_size
let y = y[0].slice(s!(0, -1.., ..));
let logits = y.slice(s!(0, ..)).to_vec();

View File

@ -3,7 +3,7 @@ use image::DynamicImage;
use ndarray::Array2;
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};
use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
use crate::{auto_load, Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
#[derive(Debug)]
pub struct Clip {
@ -28,7 +28,19 @@ impl Clip {
visual.inputs_minoptmax()[0][2].to_owned(),
visual.inputs_minoptmax()[0][3].to_owned(),
);
let mut tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap();
let tokenizer = match options_textual.tokenizer {
Some(x) => x,
None => match auto_load("tokenizer-clip.json", Some("tokenizers")) {
Err(err) => anyhow::bail!("No tokenizer's file found: {:?}", err),
Ok(x) => x,
},
};
let mut tokenizer = match Tokenizer::from_file(tokenizer) {
Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err),
Ok(x) => x,
};
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::Fixed(context_length),
direction: PaddingDirection::Right,
@ -74,10 +86,10 @@ impl Clip {
}
pub fn encode_texts(&mut self, texts: &[String]) -> Result<Y> {
let encodings = self
.tokenizer
.encode_batch(texts.to_owned(), false)
.unwrap();
let encodings = match self.tokenizer.encode_batch(texts.to_owned(), false) {
Err(err) => anyhow::bail!("{:?}", err),
Ok(x) => x,
};
let xs: Vec<f32> = encodings
.iter()
.flat_map(|i| i.get_ids().iter().map(|&b| b as f32))

View File

@ -0,0 +1,249 @@
use crate::{auto_load, Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Array, Axis};
use rayon::prelude::*;
use tokenizers::{Encoding, Tokenizer};
#[derive(Debug)]
pub struct GroundingDINO {
pub engine: OrtEngine,
height: MinOptMax,
width: MinOptMax,
batch: MinOptMax,
tokenizer: Tokenizer,
pub context_length: usize,
confs_visual: DynConf,
confs_textual: DynConf,
}
impl GroundingDINO {
pub fn new(options: Options) -> Result<Self> {
let mut engine = OrtEngine::new(&options)?;
let (batch, height, width) = (
engine.inputs_minoptmax()[0][0].to_owned(),
engine.inputs_minoptmax()[0][2].to_owned(),
engine.inputs_minoptmax()[0][3].to_owned(),
);
let context_length = options.context_length.unwrap_or(256);
// let special_tokens = ["[CLS]", "[SEP]", ".", "?"];
let tokenizer = match options.tokenizer {
Some(x) => x,
None => match auto_load("tokenizer-groundingdino.json", Some("tokenizers")) {
Err(err) => anyhow::bail!("No tokenizer's file found: {:?}", err),
Ok(x) => x,
},
};
let tokenizer = match Tokenizer::from_file(tokenizer) {
Err(err) => anyhow::bail!("Failed to build tokenizer: {:?}", err),
Ok(x) => x,
};
let confs_visual = DynConf::new(&options.confs, 1);
let confs_textual = DynConf::new(&options.confs, 1);
engine.dry_run()?;
Ok(Self {
engine,
batch,
height,
width,
tokenizer,
context_length,
confs_visual,
confs_textual,
})
}
pub fn run(&mut self, xs: &[DynamicImage], texts: &[&str]) -> Result<Vec<Y>> {
// image embeddings
let image_embeddings = X::apply(&[
Ops::Letterbox(
xs,
self.height() as u32,
self.width() as u32,
"CatmullRom",
114,
"auto",
false,
),
Ops::Normalize(0., 255.),
Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3),
Ops::Nhwc2nchw,
])?;
// encoding
let text = Self::parse_texts(texts);
let encoding = match self.tokenizer.encode(text, true) {
Err(err) => anyhow::bail!("{}", err),
Ok(x) => x,
};
let tokens = encoding.get_tokens();
// input_ids
let input_ids = X::from(
encoding
.get_ids()
.iter()
.map(|&x| x as f32)
.collect::<Vec<_>>(),
)
.insert_axis(0)?
.repeat(0, self.batch() as usize)?;
// token_type_ids
let token_type_ids = X::zeros(&[self.batch() as usize, tokens.len()]);
// attention_mask
let attention_mask = X::ones(&[self.batch() as usize, tokens.len()]);
// position_ids
let position_ids = X::from(
encoding
.get_tokens()
.iter()
.map(|x| if x == "." { 1. } else { 0. })
.collect::<Vec<_>>(),
)
.insert_axis(0)?
.repeat(0, self.batch() as usize)?;
// text_self_attention_masks
let text_self_attention_masks = Self::gen_text_self_attention_masks(&encoding)?
.insert_axis(0)?
.repeat(0, self.batch() as usize)?;
// run
let ys = self.engine.run(Xs::from(vec![
image_embeddings,
input_ids,
attention_mask,
position_ids,
token_type_ids,
text_self_attention_masks,
]))?;
// post-process
self.postprocess(ys, xs, tokens)
}
fn postprocess(&self, xs: Xs, xs0: &[DynamicImage], tokens: &[String]) -> Result<Vec<Y>> {
let ys: Vec<Y> = xs["logits"]
.axis_iter(Axis(0))
.into_par_iter()
.enumerate()
.filter_map(|(idx, logits)| {
let image_width = xs0[idx].width() as f32;
let image_height = xs0[idx].height() as f32;
let ratio =
(self.width() as f32 / image_width).min(self.height() as f32 / image_height);
let y_bboxes: Vec<Bbox> = logits
.axis_iter(Axis(0))
.into_par_iter()
.enumerate()
.filter_map(|(i, clss)| {
let (class_id, &conf) = clss
.mapv(|x| 1. / ((-x).exp() + 1.))
.iter()
.enumerate()
.max_by(|a, b| a.1.total_cmp(b.1))?;
if conf < self.conf_visual() {
return None;
}
let bbox = xs["boxes"].slice(s![idx, i, ..]).mapv(|x| x / ratio);
let cx = bbox[0] * self.width() as f32;
let cy = bbox[1] * self.height() as f32;
let w = bbox[2] * self.width() as f32;
let h = bbox[3] * self.height() as f32;
let x = cx - w / 2.;
let y = cy - h / 2.;
let x = x.max(0.0).min(image_width);
let y = y.max(0.0).min(image_height);
Some(
Bbox::default()
.with_xywh(x, y, w, h)
.with_id(class_id as _)
.with_name(&tokens[class_id])
.with_confidence(conf),
)
})
.collect();
if !y_bboxes.is_empty() {
Some(Y::default().with_bboxes(&y_bboxes))
} else {
None
}
})
.collect();
Ok(ys)
}
fn parse_texts(texts: &[&str]) -> String {
let mut y = String::new();
for text in texts.iter() {
if !text.is_empty() {
y.push_str(&format!("{} . ", text));
}
}
y
}
fn gen_text_self_attention_masks(encoding: &Encoding) -> Result<X> {
let mut vs = encoding
.get_tokens()
.iter()
.map(|x| if x == "." { 1. } else { 0. })
.collect::<Vec<_>>();
let n = vs.len();
vs[0] = 1.;
vs[n - 1] = 1.;
let mut ys = Array::zeros((n, n)).into_dyn();
let mut i_last = -1;
for (i, &v) in vs.iter().enumerate() {
if v == 0. {
if i_last == -1 {
i_last = i as isize;
} else {
i_last = -1;
}
} else if v == 1. {
if i_last == -1 {
ys.slice_mut(s![i, i]).fill(1.);
} else {
ys.slice_mut(s![i_last as _..i + 1, i_last as _..i + 1])
.fill(1.);
}
i_last = -1;
} else {
continue;
}
}
Ok(X::from(ys))
}
pub fn conf_visual(&self) -> f32 {
self.confs_visual[0]
}
pub fn conf_textual(&self) -> f32 {
self.confs_textual[0]
}
pub fn batch(&self) -> isize {
self.batch.opt
}
pub fn width(&self) -> isize {
self.width.opt
}
pub fn height(&self) -> isize {
self.height.opt
}
}

View File

@ -5,6 +5,7 @@ mod clip;
mod db;
mod depth_anything;
mod dinov2;
mod grounding_dino;
mod modnet;
mod rtmo;
mod sam;
@ -18,6 +19,7 @@ pub use clip::Clip;
pub use db::DB;
pub use depth_anything::DepthAnything;
pub use dinov2::Dinov2;
pub use grounding_dino::GroundingDINO;
pub use modnet::MODNet;
pub use rtmo::RTMO;
pub use sam::{SamKind, SamPrompt, SAM};

View File

@ -97,7 +97,7 @@ impl RTMO {
)
.with_confidence(confidence)
.with_id(0isize)
.with_name(Some(String::from("Person"))),
.with_name("Person"),
);
// keypoints

View File

@ -116,7 +116,7 @@ impl SAM {
pub fn run(&mut self, xs: &[DynamicImage], prompts: &[SamPrompt]) -> Result<Vec<Y>> {
let ys = self.encode(xs)?;
self.decode(ys, xs, prompts)
self.decode(&ys, xs, prompts)
}
pub fn encode(&mut self, xs: &[DynamicImage]) -> Result<Xs> {
@ -139,7 +139,7 @@ impl SAM {
pub fn decode(
&mut self,
xs: Xs,
xs: &Xs,
xs0: &[DynamicImage],
prompts: &[SamPrompt],
) -> Result<Vec<Y>> {
@ -157,7 +157,9 @@ impl SAM {
let args = match self.kind {
SamKind::Sam | SamKind::MobileSam => {
vec![
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?, // image_embedding
X::from(image_embedding.into_dyn().into_owned())
.insert_axis(0)?
.repeat(0, self.batch() as usize)?, // image_embedding
prompts[idx].point_coords(ratio)?, // point_coords
prompts[idx].point_labels()?, // point_labels
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input,
@ -167,10 +169,13 @@ impl SAM {
}
SamKind::SamHq => {
vec![
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?, // image_embedding
X::from(image_embedding.into_dyn().into_owned())
.insert_axis(0)?
.repeat(0, self.batch() as usize)?, // image_embedding
X::from(xs[1].slice(s![idx, .., .., ..]).into_dyn().into_owned())
.insert_axis(0)?
.insert_axis(0)?, // intern_embedding
.insert_axis(0)?
.repeat(0, self.batch() as usize)?, // intern_embedding
prompts[idx].point_coords(ratio)?, // point_coords
prompts[idx].point_labels()?, // point_labels
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
@ -180,14 +185,18 @@ impl SAM {
}
SamKind::EdgeSam => {
vec![
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
X::from(image_embedding.into_dyn().into_owned())
.insert_axis(0)?
.repeat(0, self.batch() as usize)?,
prompts[idx].point_coords(ratio)?,
prompts[idx].point_labels()?,
]
}
SamKind::Sam2 => {
vec![
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
X::from(image_embedding.into_dyn().into_owned())
.insert_axis(0)?
.repeat(0, self.batch() as usize)?,
X::from(
high_res_features_0
.unwrap()
@ -195,7 +204,8 @@ impl SAM {
.into_dyn()
.into_owned(),
)
.insert_axis(0)?,
.insert_axis(0)?
.repeat(0, self.batch() as usize)?,
X::from(
high_res_features_1
.unwrap()
@ -203,7 +213,8 @@ impl SAM {
.into_dyn()
.into_owned(),
)
.insert_axis(0)?,
.insert_axis(0)?
.repeat(0, self.batch() as usize)?,
prompts[idx].point_coords(ratio)?,
prompts[idx].point_labels()?,
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input

View File

@ -215,13 +215,13 @@ impl Vision for YOLO {
} else {
slice_clss.into_owned()
};
return Some(
y.with_probs(
&Prob::default()
.with_probs(&x.into_raw_vec())
.with_names(self.names.clone()),
),
);
let mut probs = Prob::default().with_probs(&x.into_raw_vec());
if let Some(names) = &self.names {
probs =
probs.with_names(&names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
}
return Some(y.with_probs(&probs));
}
let image_width = xs0[idx].width() as f32;
@ -312,41 +312,34 @@ impl Vision for YOLO {
(h, w, radians + std::f32::consts::PI / 2.)
};
let radians = radians % std::f32::consts::PI;
(
None,
Some(
Mbr::from_cxcywhr(
cx as f64,
cy as f64,
w as f64,
h as f64,
radians as f64,
)
.with_confidence(confidence)
.with_id(class_id as isize)
.with_name(
self.names
.as_ref()
.map(|names| names[class_id].clone()),
),
),
let mut mbr = Mbr::from_cxcywhr(
cx as f64,
cy as f64,
w as f64,
h as f64,
radians as f64,
)
.with_confidence(confidence)
.with_id(class_id as isize);
if let Some(names) = &self.names {
mbr = mbr.with_name(&names[class_id]);
}
(None, Some(mbr))
}
None => {
let mut bbox = Bbox::default()
.with_xywh(x, y, w, h)
.with_confidence(confidence)
.with_id(class_id as isize)
.with_id_born(i as isize);
if let Some(names) = &self.names {
bbox = bbox.with_name(&names[class_id]);
}
(Some(bbox), None)
}
None => (
Some(
Bbox::default()
.with_xywh(x, y, w, h)
.with_confidence(confidence)
.with_id(class_id as isize)
.with_id_born(i as isize)
.with_name(
self.names
.as_ref()
.map(|names| names[class_id].clone()),
),
),
None,
),
};
Some((y_bbox, y_mbr))
@ -390,18 +383,18 @@ impl Vision for YOLO {
if kconf < self.kconfs[i] {
Keypoint::default()
} else {
Keypoint::default()
let mut kpt = Keypoint::default()
.with_id(i as isize)
.with_confidence(kconf)
.with_name(
self.names_kpt
.as_ref()
.map(|names| names[i].clone()),
)
.with_xy(
kx.max(0.0f32).min(image_width),
ky.max(0.0f32).min(image_height),
)
);
if let Some(names) = &self.names_kpt {
kpt = kpt.with_name(&names[i]);
}
kpt
}
})
.collect::<Vec<_>>();
@ -468,23 +461,25 @@ impl Vision for YOLO {
contours
.into_par_iter()
.map(|x| {
Polygon::default()
let mut polygon = Polygon::default()
.with_id(bbox.id())
.with_points_imageproc(&x.points)
.with_name(bbox.name().cloned())
.with_points_imageproc(&x.points);
if let Some(name) = bbox.name() {
polygon = polygon.with_name(name);
}
polygon
})
.max_by(|x, y| x.area().total_cmp(&y.area()))?
} else {
Polygon::default()
};
Some((
polygons,
Mask::default()
.with_mask(mask)
.with_id(bbox.id())
.with_name(bbox.name().cloned()),
))
let mut mask = Mask::default().with_mask(mask).with_id(bbox.id());
if let Some(name) = bbox.name() {
mask = mask.with_name(name);
}
Some((polygons, mask))
})
.collect::<(Vec<_>, Vec<_>)>();

View File

@ -126,7 +126,7 @@ impl YOLOPv2 {
Polygon::default()
.with_id(0)
.with_points_imageproc(&x.points)
.with_name(Some("Drivable area".to_string()))
.with_name("Drivable area")
})
.max_by(|x, y| x.area().total_cmp(&y.area()))
{
@ -151,7 +151,7 @@ impl YOLOPv2 {
Polygon::default()
.with_id(1)
.with_points_imageproc(&x.points)
.with_name(Some("Lane line".to_string()))
.with_name("Lane line")
})
.max_by(|x, y| x.area().total_cmp(&y.area()))
{

View File

@ -205,13 +205,13 @@ impl Bbox {
///
/// # Arguments
///
/// * `x` - The optional name to be set.
/// * `x` - The name to be set.
///
/// # Returns
///
/// A `Bbox` instance with updated name.
pub fn with_name(mut self, x: Option<String>) -> Self {
self.name = x;
pub fn with_name(mut self, x: &str) -> Self {
self.name = Some(x.to_string());
self
}

View File

@ -190,8 +190,8 @@ impl Keypoint {
self
}
pub fn with_name(mut self, x: Option<String>) -> Self {
self.name = x;
pub fn with_name(mut self, x: &str) -> Self {
self.name = Some(x.to_string());
self
}

View File

@ -41,8 +41,8 @@ impl Mask {
self
}
pub fn with_name(mut self, x: Option<String>) -> Self {
self.name = x;
pub fn with_name(mut self, x: &str) -> Self {
self.name = Some(x.to_string());
self
}

View File

@ -101,8 +101,8 @@ impl Mbr {
self
}
pub fn with_name(mut self, x: Option<String>) -> Self {
self.name = x;
pub fn with_name(mut self, x: &str) -> Self {
self.name = Some(x.to_string());
self
}

View File

@ -59,8 +59,8 @@ impl Polygon {
self
}
pub fn with_name(mut self, x: Option<String>) -> Self {
self.name = x;
pub fn with_name(mut self, x: &str) -> Self {
self.name = Some(x.to_string());
self
}

View File

@ -12,8 +12,9 @@ impl std::fmt::Debug for Prob {
}
impl Prob {
pub fn with_names(mut self, x: Option<Vec<String>>) -> Self {
self.names = x;
pub fn with_names(mut self, names: &[&str]) -> Self {
let names = names.iter().map(|x| x.to_string()).collect::<Vec<String>>();
self.names = Some(names);
self
}