diff --git a/README.md b/README.md
index eb7a496..f03afd1 100644
--- a/README.md
+++ b/README.md
@@ -116,6 +116,7 @@
| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection
Open-Set Keypoints Detection
Image Caption
Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | |
| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | ✅ | ✅ | ✅ | | |
+| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation Answering | [demo](examples/rmbg) | ✅ | ✅ | ✅ | | |
diff --git a/examples/rmbg/README.md b/examples/rmbg/README.md
new file mode 100644
index 0000000..39c6a1f
--- /dev/null
+++ b/examples/rmbg/README.md
@@ -0,0 +1,9 @@
+## Quick Start
+
+```shell
+cargo run -r --example rmbg -- --ver 1.4 --dtype fp16
+```
+
+## Results
+
+
diff --git a/examples/rmbg/main.rs b/examples/rmbg/main.rs
new file mode 100644
index 0000000..cb34be0
--- /dev/null
+++ b/examples/rmbg/main.rs
@@ -0,0 +1,58 @@
+use usls::{models::RMBG, Annotator, DataLoader, Options};
+
+#[derive(argh::FromArgs)]
+/// Example
+struct Args {
+ /// dtype
+ #[argh(option, default = "String::from(\"auto\")")]
+ dtype: String,
+
+ /// device
+ #[argh(option, default = "String::from(\"cpu:0\")")]
+ device: String,
+
+ /// version
+ #[argh(option, default = "1.4")]
+ ver: f32,
+}
+
+fn main() -> anyhow::Result<()> {
+ tracing_subscriber::fmt()
+ .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+ .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
+ .init();
+ let args: Args = argh::from_env();
+
+ let options = match args.ver {
+ 1.4 => Options::rmbg1_4(),
+ 2.0 => Options::rmbg2_0(),
+ _ => unreachable!("Unsupported version"),
+ };
+
+ // build model
+ let options = options
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?;
+ let mut model = RMBG::new(options)?;
+
+ // load image
+ let xs = DataLoader::try_read_n(&["./assets/cat.png"])?;
+
+ // run
+ let ys = model.forward(&xs)?;
+
+ // annotate
+ let annotator = Annotator::default();
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ annotator.annotate(x, y)?.save(format!(
+ "{}.jpg",
+ usls::Dir::Current
+ .base_dir_with_subs(&["runs", model.spec()])?
+ .join(usls::timestamp(None))
+ .display(),
+ ))?;
+ }
+
+ Ok(())
+}
diff --git a/examples/yolo-sam2/README.md b/examples/yolo-sam2/README.md
index 84dfb0f..93c41f8 100644
--- a/examples/yolo-sam2/README.md
+++ b/examples/yolo-sam2/README.md
@@ -1,7 +1,7 @@
## Quick Start
```shell
-cargo run -r -F cuda --example yolo-sam -- --device cuda
+cargo run -r -F cuda --example yolo-sam2 -- --device cuda
```
## Results
diff --git a/examples/yolo-sam2/main.rs b/examples/yolo-sam2/main.rs
index bcf9634..8b60d87 100644
--- a/examples/yolo-sam2/main.rs
+++ b/examples/yolo-sam2/main.rs
@@ -39,15 +39,11 @@ fn main() -> Result<()> {
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
// build annotator
- let annotator = Annotator::default()
- .with_polygon_style(
- Style::polygon()
- .with_visible(true)
- .with_text_visible(true)
- .show_id(true)
- .show_name(true),
- )
- .with_mask_style(Style::mask().with_draw_mask_polygon_largest(true));
+ let annotator = Annotator::default().with_mask_style(
+ Style::mask()
+ .with_draw_mask_polygon_largest(true)
+ .with_draw_mask_hbbs(true),
+ );
// run & annotate
let ys_det = yolo.forward(&xs)?;
diff --git a/src/inference/polygon.rs b/src/inference/polygon.rs
index a270f68..1eecb96 100644
--- a/src/inference/polygon.rs
+++ b/src/inference/polygon.rs
@@ -123,12 +123,23 @@ impl Polygon {
pub fn hbb(&self) -> Option {
use geo::BoundingRect;
self.polygon.bounding_rect().map(|x| {
- Hbb::default().with_xyxy(
+ let mut hbb = Hbb::default().with_xyxy(
x.min().x as f32,
x.min().y as f32,
x.max().x as f32,
x.max().y as f32,
- )
+ );
+ if let Some(id) = self.id() {
+ hbb = hbb.with_id(id);
+ }
+ if let Some(name) = self.name() {
+ hbb = hbb.with_name(name);
+ }
+ if let Some(confidence) = self.confidence() {
+ hbb = hbb.with_confidence(confidence);
+ }
+
+ hbb
})
}
@@ -138,11 +149,22 @@ impl Polygon {
let xy4 = x
.exterior()
.coords()
- // .iter()
.map(|c| [c.x as f32, c.y as f32])
.collect::>();
- Obb::from(xy4)
+ let mut obb = Obb::from(xy4);
+
+ if let Some(id) = self.id() {
+ obb = obb.with_id(id);
+ }
+ if let Some(name) = self.name() {
+ obb = obb.with_name(name);
+ }
+ if let Some(confidence) = self.confidence() {
+ obb = obb.with_confidence(confidence);
+ }
+
+ obb
})
}
diff --git a/src/models/depth_anything/impl.rs b/src/models/depth_anything/impl.rs
index c21a79b..f094cc6 100644
--- a/src/models/depth_anything/impl.rs
+++ b/src/models/depth_anything/impl.rs
@@ -75,12 +75,7 @@ impl DepthAnything {
false,
"Bilinear",
)?;
- let luma: image::ImageBuffer, Vec<_>> =
- match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
- None => continue,
- Some(x) => x,
- };
- ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
+ ys.push(Y::default().with_masks(&[Mask::new(&luma, w1, h1)?]));
}
Ok(ys)
diff --git a/src/models/depth_pro/impl.rs b/src/models/depth_pro/impl.rs
index 32f821f..6e1d254 100644
--- a/src/models/depth_pro/impl.rs
+++ b/src/models/depth_pro/impl.rs
@@ -76,12 +76,7 @@ impl DepthPro {
false,
"Bilinear",
)?;
- let luma: image::ImageBuffer, Vec<_>> =
- match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
- None => continue,
- Some(x) => x,
- };
- ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
+ ys.push(Y::default().with_masks(&[Mask::new(&luma, w1, h1)?]));
}
Ok(ys)
diff --git a/src/models/grounding_dino/config.rs b/src/models/grounding_dino/config.rs
index 0ec4f00..ce7096d 100644
--- a/src/models/grounding_dino/config.rs
+++ b/src/models/grounding_dino/config.rs
@@ -19,4 +19,8 @@ impl crate::Options {
pub fn grounding_dino_tiny() -> Self {
Self::grounding_dino().with_model_file("swint-ogc.onnx")
}
+
+ pub fn grounding_dino_base() -> Self {
+ Self::grounding_dino().with_model_file("swinb-cogcoor.onnx")
+ }
}
diff --git a/src/models/mod.rs b/src/models/mod.rs
index ad5f42b..13303dc 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -21,6 +21,7 @@ mod owl;
mod picodet;
mod pipeline;
mod rfdetr;
+mod rmbg;
mod rtdetr;
mod rtmo;
mod sam;
@@ -47,6 +48,7 @@ pub use owl::*;
pub use picodet::*;
pub use pipeline::*;
pub use rfdetr::*;
+pub use rmbg::*;
pub use rtdetr::*;
pub use rtmo::*;
pub use sam::*;
diff --git a/src/models/rmbg/README.md b/src/models/rmbg/README.md
new file mode 100644
index 0000000..be44b75
--- /dev/null
+++ b/src/models/rmbg/README.md
@@ -0,0 +1,9 @@
+# RMBG: BRIA Background Removal
+
+## Official Repository
+
+The official repository can be found on: [HuggingFace](https://huggingface.co/briaai/RMBG-2.0)
+
+## Example
+
+Refer to the [example](../../../examples/rmbg)
diff --git a/src/models/rmbg/config.rs b/src/models/rmbg/config.rs
new file mode 100644
index 0000000..e44ce8f
--- /dev/null
+++ b/src/models/rmbg/config.rs
@@ -0,0 +1,24 @@
+/// Model configuration for `RMBG-2.0`
+impl crate::Options {
+ pub fn rmbg() -> Self {
+ Self::default()
+ .with_model_name("rmbg")
+ .with_model_ixx(0, 0, 1.into())
+ .with_model_ixx(0, 2, 1024.into())
+ .with_model_ixx(0, 3, 1024.into())
+ }
+
+ pub fn rmbg1_4() -> Self {
+ Self::rmbg()
+ .with_image_mean(&[0.5, 0.5, 0.5])
+ .with_image_std(&[1., 1., 1.])
+ .with_model_file("1.4.onnx")
+ }
+
+ pub fn rmbg2_0() -> Self {
+ Self::rmbg()
+ .with_image_mean(&[0.485, 0.456, 0.406])
+ .with_image_std(&[0.229, 0.224, 0.225])
+ .with_model_file("2.0.onnx")
+ }
+}
diff --git a/src/models/rmbg/impl.rs b/src/models/rmbg/impl.rs
new file mode 100644
index 0000000..f886fa6
--- /dev/null
+++ b/src/models/rmbg/impl.rs
@@ -0,0 +1,93 @@
+use aksr::Builder;
+use anyhow::Result;
+
+use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y};
+
+#[derive(Builder, Debug)]
+pub struct RMBG {
+ engine: Engine,
+ height: usize,
+ width: usize,
+ batch: usize,
+ ts: Ts,
+ spec: String,
+ processor: Processor,
+}
+
+impl RMBG {
+ pub fn new(options: Options) -> Result {
+ let engine = options.to_engine()?;
+ let spec = engine.spec().to_string();
+ let (batch, height, width, ts) = (
+ engine.batch().opt(),
+ engine.try_height().unwrap_or(&1024.into()).opt(),
+ engine.try_width().unwrap_or(&1024.into()).opt(),
+ engine.ts().clone(),
+ );
+ let processor = options
+ .to_processor()?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
+
+ Ok(Self {
+ engine,
+ height,
+ width,
+ batch,
+ ts,
+ spec,
+ processor,
+ })
+ }
+
+ fn preprocess(&mut self, xs: &[Image]) -> Result {
+ Ok(self.processor.process_images(xs)?.into())
+ }
+
+ fn inference(&mut self, xs: Xs) -> Result {
+ self.engine.run(xs)
+ }
+
+ pub fn forward(&mut self, xs: &[Image]) -> Result> {
+ let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? });
+ let ys = elapsed!("inference", self.ts, { self.inference(ys)? });
+ let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? });
+
+ Ok(ys)
+ }
+
+ pub fn summary(&mut self) {
+ self.ts.summary();
+ }
+
+ fn postprocess(&mut self, xs: Xs) -> Result> {
+ let mut ys: Vec = Vec::new();
+ for (idx, luma) in xs[0].axis_iter(ndarray::Axis(0)).enumerate() {
+ // image size
+ let (h1, w1) = (
+ self.processor.images_transform_info[idx].height_src,
+ self.processor.images_transform_info[idx].width_src,
+ );
+ let v = luma.into_owned().into_raw_vec_and_offset().0;
+ let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
+ let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap();
+ let v = v
+ .iter()
+ .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
+ .collect::>();
+
+ let luma = Ops::resize_luma8_u8(
+ &v,
+ self.width() as _,
+ self.height() as _,
+ w1 as _,
+ h1 as _,
+ false,
+ "Bilinear",
+ )?;
+ ys.push(Y::default().with_masks(&[Mask::new(&luma, w1, h1)?]));
+ }
+
+ Ok(ys)
+ }
+}
diff --git a/src/models/rmbg/mod.rs b/src/models/rmbg/mod.rs
new file mode 100644
index 0000000..fbd2b75
--- /dev/null
+++ b/src/models/rmbg/mod.rs
@@ -0,0 +1,4 @@
+mod config;
+mod r#impl;
+
+pub use r#impl::*;
diff --git a/src/viz/drawable/hbb.rs b/src/viz/drawable/hbb.rs
index b5b5728..6837e59 100644
--- a/src/viz/drawable/hbb.rs
+++ b/src/viz/drawable/hbb.rs
@@ -34,7 +34,10 @@ impl Drawable for Hbb {
imageproc::drawing::draw_filled_rect_mut(
&mut overlay,
imageproc::rect::Rect::at(self.xmin().round() as i32, self.ymin().round() as i32)
- .of_size(self.width().round() as u32, self.height().round() as u32),
+ .of_size(
+ (self.width().round() as u32).max(1),
+ (self.height().round() as u32).max(1),
+ ),
Rgba(style.color().fill.unwrap().into()),
);
image::imageops::overlay(canvas, &overlay, 0, 0);
@@ -43,7 +46,7 @@ impl Drawable for Hbb {
if style.draw_outline() {
let short_side_threshold =
self.width().min(self.height()) * style.thickness_threshold();
- let thickness = style.thickness().min(short_side_threshold as usize);
+ let thickness = style.thickness().min(short_side_threshold as usize).max(1);
for i in 0..thickness {
imageproc::drawing::draw_hollow_rect_mut(
canvas,
@@ -52,8 +55,8 @@ impl Drawable for Hbb {
(self.ymin().round() as i32) - (i as i32),
)
.of_size(
- (self.width().round() as u32) + (2 * i as u32),
- (self.height().round() as u32) + (2 * i as u32),
+ ((self.width().round() as u32) + (2 * i as u32)).max(1),
+ ((self.height().round() as u32) + (2 * i as u32)).max(1),
),
Rgba(style.color().outline.unwrap().into()),
);
@@ -78,7 +81,7 @@ impl Drawable for Hbb {
if style.draw_text() {
let short_side_threshold =
self.width().min(self.height()) * style.thickness_threshold();
- let thickness = style.thickness().min(short_side_threshold as usize);
+ let thickness = style.thickness().min(short_side_threshold as usize).max(1);
let label = self.meta().label(
style.id(),
diff --git a/src/viz/drawable/mask.rs b/src/viz/drawable/mask.rs
index e9652d4..474c7c5 100644
--- a/src/viz/drawable/mask.rs
+++ b/src/viz/drawable/mask.rs
@@ -89,12 +89,29 @@ impl Drawable for Vec {
polygon.draw(ctx, canvas)?;
}
}
+
if style.draw_mask_polygons() {
for polygon in mask.polygons() {
polygon.draw(ctx, canvas)?;
}
}
+ if style.draw_mask_hbbs() {
+ if let Some(polygon) = mask.polygon() {
+ if let Some(hbb) = polygon.hbb() {
+ hbb.draw(ctx, canvas)?;
+ }
+ }
+ }
+
+ if style.draw_mask_obbs() {
+ if let Some(polygon) = mask.polygon() {
+ if let Some(obb) = polygon.obb() {
+ obb.draw(ctx, canvas)?;
+ }
+ }
+ }
+
if style.visible() {
masks_visible.push(mask);
}
@@ -140,6 +157,22 @@ impl Drawable for Mask {
}
}
+ if style.draw_mask_hbbs() {
+ if let Some(polygon) = self.polygon() {
+ if let Some(hbb) = polygon.hbb() {
+ hbb.draw(ctx, canvas)?;
+ }
+ }
+ }
+
+ if style.draw_mask_obbs() {
+ if let Some(polygon) = self.polygon() {
+ if let Some(obb) = polygon.obb() {
+ obb.draw(ctx, canvas)?;
+ }
+ }
+ }
+
if style.visible() {
let (w, h) = canvas.dimensions();
let mask_dyn = render_mask(self, style.colormap256());
diff --git a/src/viz/styles.rs b/src/viz/styles.rs
index 1b79192..beecef4 100644
--- a/src/viz/styles.rs
+++ b/src/viz/styles.rs
@@ -16,6 +16,8 @@ pub struct Style {
thickness_threshold: f32, // For Hbb
draw_mask_polygons: bool, // For Masks
draw_mask_polygon_largest: bool, // For Masks
+ draw_mask_hbbs: bool, // For Masks
+ draw_mask_obbs: bool, // For Masks
text_loc: TextLoc, // For ALL
color: StyleColors, // For ALL
palette: Vec, // For ALL
@@ -41,6 +43,8 @@ impl Default for Style {
color_fill_alpha: None,
draw_mask_polygons: false,
draw_mask_polygon_largest: false,
+ draw_mask_hbbs: false,
+ draw_mask_obbs: false,
radius: 3,
text_x_pos: 0.05,
text_y_pos: 0.05,