Add RMBG model (#90)

This commit is contained in:
Jamjamjon
2025-05-14 21:48:56 +08:00
committed by GitHub
parent 675dd63734
commit 57cb1ac77a
17 changed files with 283 additions and 31 deletions

View File

@ -116,6 +116,7 @@
| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | | | [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | | | [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | |
| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | ✅ | ✅ | ✅ | | | | [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | ✅ | ✅ | ✅ | | |
| [RMBG(1.4, 2.0)](https://huggingface.co/briaai/RMBG-2.0) | Image Segmentation Answering | [demo](examples/rmbg) | ✅ | ✅ | ✅ | | |
</details> </details>

9
examples/rmbg/README.md Normal file
View File

@ -0,0 +1,9 @@
## Quick Start
```shell
cargo run -r --example rmbg -- --ver 1.4 --dtype fp16
```
## Results
![](https://github.com/jamjamjon/assets/releases/download/rmbg/demo.jpg)

58
examples/rmbg/main.rs Normal file
View File

@ -0,0 +1,58 @@
use usls::{models::RMBG, Annotator, DataLoader, Options};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// dtype
#[argh(option, default = "String::from(\"auto\")")]
dtype: String,
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
/// version
#[argh(option, default = "1.4")]
ver: f32,
}
fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
let args: Args = argh::from_env();
let options = match args.ver {
1.4 => Options::rmbg1_4(),
2.0 => Options::rmbg2_0(),
_ => unreachable!("Unsupported version"),
};
// build model
let options = options
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let mut model = RMBG::new(options)?;
// load image
let xs = DataLoader::try_read_n(&["./assets/cat.png"])?;
// run
let ys = model.forward(&xs)?;
// annotate
let annotator = Annotator::default();
for (x, y) in xs.iter().zip(ys.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",
usls::Dir::Current
.base_dir_with_subs(&["runs", model.spec()])?
.join(usls::timestamp(None))
.display(),
))?;
}
Ok(())
}

View File

@ -1,7 +1,7 @@
## Quick Start ## Quick Start
```shell ```shell
cargo run -r -F cuda --example yolo-sam -- --device cuda cargo run -r -F cuda --example yolo-sam2 -- --device cuda
``` ```
## Results ## Results

View File

@ -39,15 +39,11 @@ fn main() -> Result<()> {
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
// build annotator // build annotator
let annotator = Annotator::default() let annotator = Annotator::default().with_mask_style(
.with_polygon_style( Style::mask()
Style::polygon() .with_draw_mask_polygon_largest(true)
.with_visible(true) .with_draw_mask_hbbs(true),
.with_text_visible(true) );
.show_id(true)
.show_name(true),
)
.with_mask_style(Style::mask().with_draw_mask_polygon_largest(true));
// run & annotate // run & annotate
let ys_det = yolo.forward(&xs)?; let ys_det = yolo.forward(&xs)?;

View File

@ -123,12 +123,23 @@ impl Polygon {
pub fn hbb(&self) -> Option<Hbb> { pub fn hbb(&self) -> Option<Hbb> {
use geo::BoundingRect; use geo::BoundingRect;
self.polygon.bounding_rect().map(|x| { self.polygon.bounding_rect().map(|x| {
Hbb::default().with_xyxy( let mut hbb = Hbb::default().with_xyxy(
x.min().x as f32, x.min().x as f32,
x.min().y as f32, x.min().y as f32,
x.max().x as f32, x.max().x as f32,
x.max().y as f32, x.max().y as f32,
) );
if let Some(id) = self.id() {
hbb = hbb.with_id(id);
}
if let Some(name) = self.name() {
hbb = hbb.with_name(name);
}
if let Some(confidence) = self.confidence() {
hbb = hbb.with_confidence(confidence);
}
hbb
}) })
} }
@ -138,11 +149,22 @@ impl Polygon {
let xy4 = x let xy4 = x
.exterior() .exterior()
.coords() .coords()
// .iter()
.map(|c| [c.x as f32, c.y as f32]) .map(|c| [c.x as f32, c.y as f32])
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Obb::from(xy4) let mut obb = Obb::from(xy4);
if let Some(id) = self.id() {
obb = obb.with_id(id);
}
if let Some(name) = self.name() {
obb = obb.with_name(name);
}
if let Some(confidence) = self.confidence() {
obb = obb.with_confidence(confidence);
}
obb
}) })
} }

View File

@ -75,12 +75,7 @@ impl DepthAnything {
false, false,
"Bilinear", "Bilinear",
)?; )?;
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> = ys.push(Y::default().with_masks(&[Mask::new(&luma, w1, h1)?]));
match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
None => continue,
Some(x) => x,
};
ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
} }
Ok(ys) Ok(ys)

View File

@ -76,12 +76,7 @@ impl DepthPro {
false, false,
"Bilinear", "Bilinear",
)?; )?;
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> = ys.push(Y::default().with_masks(&[Mask::new(&luma, w1, h1)?]));
match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
None => continue,
Some(x) => x,
};
ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
} }
Ok(ys) Ok(ys)

View File

@ -19,4 +19,8 @@ impl crate::Options {
pub fn grounding_dino_tiny() -> Self { pub fn grounding_dino_tiny() -> Self {
Self::grounding_dino().with_model_file("swint-ogc.onnx") Self::grounding_dino().with_model_file("swint-ogc.onnx")
} }
pub fn grounding_dino_base() -> Self {
Self::grounding_dino().with_model_file("swinb-cogcoor.onnx")
}
} }

View File

@ -21,6 +21,7 @@ mod owl;
mod picodet; mod picodet;
mod pipeline; mod pipeline;
mod rfdetr; mod rfdetr;
mod rmbg;
mod rtdetr; mod rtdetr;
mod rtmo; mod rtmo;
mod sam; mod sam;
@ -47,6 +48,7 @@ pub use owl::*;
pub use picodet::*; pub use picodet::*;
pub use pipeline::*; pub use pipeline::*;
pub use rfdetr::*; pub use rfdetr::*;
pub use rmbg::*;
pub use rtdetr::*; pub use rtdetr::*;
pub use rtmo::*; pub use rtmo::*;
pub use sam::*; pub use sam::*;

View File

@ -0,0 +1,9 @@
# RMBG: BRIA Background Removal
## Official Repository
The official repository can be found on: [HuggingFace](https://huggingface.co/briaai/RMBG-2.0)
## Example
Refer to the [example](../../../examples/rmbg)

24
src/models/rmbg/config.rs Normal file
View File

@ -0,0 +1,24 @@
/// Model configuration for `RMBG-2.0`
impl crate::Options {
pub fn rmbg() -> Self {
Self::default()
.with_model_name("rmbg")
.with_model_ixx(0, 0, 1.into())
.with_model_ixx(0, 2, 1024.into())
.with_model_ixx(0, 3, 1024.into())
}
pub fn rmbg1_4() -> Self {
Self::rmbg()
.with_image_mean(&[0.5, 0.5, 0.5])
.with_image_std(&[1., 1., 1.])
.with_model_file("1.4.onnx")
}
pub fn rmbg2_0() -> Self {
Self::rmbg()
.with_image_mean(&[0.485, 0.456, 0.406])
.with_image_std(&[0.229, 0.224, 0.225])
.with_model_file("2.0.onnx")
}
}

93
src/models/rmbg/impl.rs Normal file
View File

@ -0,0 +1,93 @@
use aksr::Builder;
use anyhow::Result;
use crate::{elapsed, Engine, Image, Mask, Ops, Options, Processor, Ts, Xs, Y};
#[derive(Builder, Debug)]
pub struct RMBG {
engine: Engine,
height: usize,
width: usize,
batch: usize,
ts: Ts,
spec: String,
processor: Processor,
}
impl RMBG {
pub fn new(options: Options) -> Result<Self> {
let engine = options.to_engine()?;
let spec = engine.spec().to_string();
let (batch, height, width, ts) = (
engine.batch().opt(),
engine.try_height().unwrap_or(&1024.into()).opt(),
engine.try_width().unwrap_or(&1024.into()).opt(),
engine.ts().clone(),
);
let processor = options
.to_processor()?
.with_image_width(width as _)
.with_image_height(height as _);
Ok(Self {
engine,
height,
width,
batch,
ts,
spec,
processor,
})
}
fn preprocess(&mut self, xs: &[Image]) -> Result<Xs> {
Ok(self.processor.process_images(xs)?.into())
}
fn inference(&mut self, xs: Xs) -> Result<Xs> {
self.engine.run(xs)
}
pub fn forward(&mut self, xs: &[Image]) -> Result<Vec<Y>> {
let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? });
let ys = elapsed!("inference", self.ts, { self.inference(ys)? });
let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? });
Ok(ys)
}
pub fn summary(&mut self) {
self.ts.summary();
}
fn postprocess(&mut self, xs: Xs) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
for (idx, luma) in xs[0].axis_iter(ndarray::Axis(0)).enumerate() {
// image size
let (h1, w1) = (
self.processor.images_transform_info[idx].height_src,
self.processor.images_transform_info[idx].width_src,
);
let v = luma.into_owned().into_raw_vec_and_offset().0;
let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap();
let v = v
.iter()
.map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
.collect::<Vec<_>>();
let luma = Ops::resize_luma8_u8(
&v,
self.width() as _,
self.height() as _,
w1 as _,
h1 as _,
false,
"Bilinear",
)?;
ys.push(Y::default().with_masks(&[Mask::new(&luma, w1, h1)?]));
}
Ok(ys)
}
}

4
src/models/rmbg/mod.rs Normal file
View File

@ -0,0 +1,4 @@
mod config;
mod r#impl;
pub use r#impl::*;

View File

@ -34,7 +34,10 @@ impl Drawable for Hbb {
imageproc::drawing::draw_filled_rect_mut( imageproc::drawing::draw_filled_rect_mut(
&mut overlay, &mut overlay,
imageproc::rect::Rect::at(self.xmin().round() as i32, self.ymin().round() as i32) imageproc::rect::Rect::at(self.xmin().round() as i32, self.ymin().round() as i32)
.of_size(self.width().round() as u32, self.height().round() as u32), .of_size(
(self.width().round() as u32).max(1),
(self.height().round() as u32).max(1),
),
Rgba(style.color().fill.unwrap().into()), Rgba(style.color().fill.unwrap().into()),
); );
image::imageops::overlay(canvas, &overlay, 0, 0); image::imageops::overlay(canvas, &overlay, 0, 0);
@ -43,7 +46,7 @@ impl Drawable for Hbb {
if style.draw_outline() { if style.draw_outline() {
let short_side_threshold = let short_side_threshold =
self.width().min(self.height()) * style.thickness_threshold(); self.width().min(self.height()) * style.thickness_threshold();
let thickness = style.thickness().min(short_side_threshold as usize); let thickness = style.thickness().min(short_side_threshold as usize).max(1);
for i in 0..thickness { for i in 0..thickness {
imageproc::drawing::draw_hollow_rect_mut( imageproc::drawing::draw_hollow_rect_mut(
canvas, canvas,
@ -52,8 +55,8 @@ impl Drawable for Hbb {
(self.ymin().round() as i32) - (i as i32), (self.ymin().round() as i32) - (i as i32),
) )
.of_size( .of_size(
(self.width().round() as u32) + (2 * i as u32), ((self.width().round() as u32) + (2 * i as u32)).max(1),
(self.height().round() as u32) + (2 * i as u32), ((self.height().round() as u32) + (2 * i as u32)).max(1),
), ),
Rgba(style.color().outline.unwrap().into()), Rgba(style.color().outline.unwrap().into()),
); );
@ -78,7 +81,7 @@ impl Drawable for Hbb {
if style.draw_text() { if style.draw_text() {
let short_side_threshold = let short_side_threshold =
self.width().min(self.height()) * style.thickness_threshold(); self.width().min(self.height()) * style.thickness_threshold();
let thickness = style.thickness().min(short_side_threshold as usize); let thickness = style.thickness().min(short_side_threshold as usize).max(1);
let label = self.meta().label( let label = self.meta().label(
style.id(), style.id(),

View File

@ -89,12 +89,29 @@ impl Drawable for Vec<Mask> {
polygon.draw(ctx, canvas)?; polygon.draw(ctx, canvas)?;
} }
} }
if style.draw_mask_polygons() { if style.draw_mask_polygons() {
for polygon in mask.polygons() { for polygon in mask.polygons() {
polygon.draw(ctx, canvas)?; polygon.draw(ctx, canvas)?;
} }
} }
if style.draw_mask_hbbs() {
if let Some(polygon) = mask.polygon() {
if let Some(hbb) = polygon.hbb() {
hbb.draw(ctx, canvas)?;
}
}
}
if style.draw_mask_obbs() {
if let Some(polygon) = mask.polygon() {
if let Some(obb) = polygon.obb() {
obb.draw(ctx, canvas)?;
}
}
}
if style.visible() { if style.visible() {
masks_visible.push(mask); masks_visible.push(mask);
} }
@ -140,6 +157,22 @@ impl Drawable for Mask {
} }
} }
if style.draw_mask_hbbs() {
if let Some(polygon) = self.polygon() {
if let Some(hbb) = polygon.hbb() {
hbb.draw(ctx, canvas)?;
}
}
}
if style.draw_mask_obbs() {
if let Some(polygon) = self.polygon() {
if let Some(obb) = polygon.obb() {
obb.draw(ctx, canvas)?;
}
}
}
if style.visible() { if style.visible() {
let (w, h) = canvas.dimensions(); let (w, h) = canvas.dimensions();
let mask_dyn = render_mask(self, style.colormap256()); let mask_dyn = render_mask(self, style.colormap256());

View File

@ -16,6 +16,8 @@ pub struct Style {
thickness_threshold: f32, // For Hbb thickness_threshold: f32, // For Hbb
draw_mask_polygons: bool, // For Masks draw_mask_polygons: bool, // For Masks
draw_mask_polygon_largest: bool, // For Masks draw_mask_polygon_largest: bool, // For Masks
draw_mask_hbbs: bool, // For Masks
draw_mask_obbs: bool, // For Masks
text_loc: TextLoc, // For ALL text_loc: TextLoc, // For ALL
color: StyleColors, // For ALL color: StyleColors, // For ALL
palette: Vec<Color>, // For ALL palette: Vec<Color>, // For ALL
@ -41,6 +43,8 @@ impl Default for Style {
color_fill_alpha: None, color_fill_alpha: None,
draw_mask_polygons: false, draw_mask_polygons: false,
draw_mask_polygon_largest: false, draw_mask_polygon_largest: false,
draw_mask_hbbs: false,
draw_mask_obbs: false,
radius: 3, radius: 3,
text_x_pos: 0.05, text_x_pos: 0.05,
text_y_pos: 0.05, text_y_pos: 0.05,