Add RMBG model (#90)

This commit is contained in:
Jamjamjon
2025-05-14 21:48:56 +08:00
committed by GitHub
parent 675dd63734
commit 57cb1ac77a
17 changed files with 283 additions and 31 deletions

9
examples/rmbg/README.md Normal file
View File

@@ -0,0 +1,9 @@
## Quick Start
```shell
cargo run -r --example rmbg -- --ver 1.4 --dtype fp16
```
## Results
![](https://github.com/jamjamjon/assets/releases/download/rmbg/demo.jpg)

58
examples/rmbg/main.rs Normal file
View File

@@ -0,0 +1,58 @@
use usls::{models::RMBG, Annotator, DataLoader, Options};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// dtype
#[argh(option, default = "String::from(\"auto\")")]
dtype: String,
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
/// version
#[argh(option, default = "1.4")]
ver: f32,
}
fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
let args: Args = argh::from_env();
let options = match args.ver {
1.4 => Options::rmbg1_4(),
2.0 => Options::rmbg2_0(),
_ => unreachable!("Unsupported version"),
};
// build model
let options = options
.with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let mut model = RMBG::new(options)?;
// load image
let xs = DataLoader::try_read_n(&["./assets/cat.png"])?;
// run
let ys = model.forward(&xs)?;
// annotate
let annotator = Annotator::default();
for (x, y) in xs.iter().zip(ys.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",
usls::Dir::Current
.base_dir_with_subs(&["runs", model.spec()])?
.join(usls::timestamp(None))
.display(),
))?;
}
Ok(())
}

View File

@@ -1,7 +1,7 @@
## Quick Start
```shell
cargo run -r -F cuda --example yolo-sam -- --device cuda
cargo run -r -F cuda --example yolo-sam2 -- --device cuda
```
## Results

View File

@@ -39,15 +39,11 @@ fn main() -> Result<()> {
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
// build annotator
let annotator = Annotator::default()
.with_polygon_style(
Style::polygon()
.with_visible(true)
.with_text_visible(true)
.show_id(true)
.show_name(true),
)
.with_mask_style(Style::mask().with_draw_mask_polygon_largest(true));
let annotator = Annotator::default().with_mask_style(
Style::mask()
.with_draw_mask_polygon_largest(true)
.with_draw_mask_hbbs(true),
);
// run & annotate
let ys_det = yolo.forward(&xs)?;