mirror of
https://github.com/mii443/usls.git
synced 2025-12-03 02:58:22 +00:00
Add SAM2.1 models and support batched prompt inputs (#89)
This commit is contained in:
@@ -16,7 +16,7 @@ struct Args {
|
||||
scale: String,
|
||||
|
||||
/// SAM kind
|
||||
#[argh(option, default = "String::from(\"sam\")")]
|
||||
#[argh(option, default = "String::from(\"samhq\")")]
|
||||
kind: String,
|
||||
}
|
||||
|
||||
@@ -69,9 +69,19 @@ fn main() -> Result<()> {
|
||||
// Prompt
|
||||
let prompts = vec![
|
||||
SamPrompt::default()
|
||||
// .with_postive_point(500., 375.), // postive point
|
||||
// .with_negative_point(774., 366.), // negative point
|
||||
.with_bbox(215., 297., 643., 459.), // bbox
|
||||
// // # demo: point + point
|
||||
// .with_positive_point(500., 375.) // mid window
|
||||
// .with_positive_point(1125., 625.), // car door
|
||||
// // # demo: bbox
|
||||
// .with_xyxy(425., 600., 700., 875.), // left wheel
|
||||
// // Note: When specifying multiple boxes for multiple objects, only the last box is supported; all previous boxes will be ignored.
|
||||
// .with_xyxy(75., 275., 1725., 850.)
|
||||
// .with_xyxy(425., 600., 700., 875.)
|
||||
// .with_xyxy(1240., 675., 1400., 750.)
|
||||
// .with_xyxy(1375., 550., 1650., 800.)
|
||||
// # demo: bbox + negative point
|
||||
.with_xyxy(425., 600., 700., 875.) // left wheel
|
||||
.with_negative_point(575., 750.), // tire
|
||||
];
|
||||
|
||||
// Run & Annotate
|
||||
|
||||
6
examples/sam2/README.md
Normal file
6
examples/sam2/README.md
Normal file
@@ -0,0 +1,6 @@
|
||||
## Quick Start
|
||||
|
||||
```Shell
|
||||
|
||||
cargo run -r -F cuda --example sam -- --device cuda --scale t
|
||||
```
|
||||
93
examples/sam2/main.rs
Normal file
93
examples/sam2/main.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use anyhow::Result;
|
||||
use usls::{
|
||||
models::{SamPrompt, SAM2},
|
||||
Annotator, DataLoader, Options, Scale,
|
||||
};
|
||||
|
||||
#[derive(argh::FromArgs)]
|
||||
/// Example
|
||||
struct Args {
|
||||
/// device
|
||||
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||
device: String,
|
||||
|
||||
/// scale
|
||||
#[argh(option, default = "String::from(\"t\")")]
|
||||
scale: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
||||
.init();
|
||||
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// Build model
|
||||
let (options_encoder, options_decoder) = match args.scale.as_str().try_into()? {
|
||||
Scale::T => (
|
||||
Options::sam2_1_tiny_encoder(),
|
||||
Options::sam2_1_tiny_decoder(),
|
||||
),
|
||||
Scale::S => (
|
||||
Options::sam2_1_small_encoder(),
|
||||
Options::sam2_1_small_decoder(),
|
||||
),
|
||||
Scale::B => (
|
||||
Options::sam2_1_base_plus_encoder(),
|
||||
Options::sam2_1_base_plus_decoder(),
|
||||
),
|
||||
Scale::L => (
|
||||
Options::sam2_1_large_encoder(),
|
||||
Options::sam2_1_large_decoder(),
|
||||
),
|
||||
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale),
|
||||
};
|
||||
|
||||
let options_encoder = options_encoder
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.commit()?;
|
||||
let options_decoder = options_decoder
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.commit()?;
|
||||
let mut model = SAM2::new(options_encoder, options_decoder)?;
|
||||
|
||||
// Load image
|
||||
let xs = DataLoader::try_read_n(&["images/truck.jpg"])?;
|
||||
|
||||
// Prompt
|
||||
let prompts = vec![SamPrompt::default()
|
||||
// // # demo: point + point
|
||||
// .with_positive_point(500., 375.) // mid window
|
||||
// .with_positive_point(1125., 625.), // car door
|
||||
// // # demo: bbox
|
||||
// .with_xyxy(425., 600., 700., 875.), // left wheel
|
||||
// // # demo: bbox + negative point
|
||||
// .with_xyxy(425., 600., 700., 875.) // left wheel
|
||||
// .with_negative_point(575., 750.), // tire
|
||||
// # demo: multiple objects with boxes
|
||||
.with_xyxy(75., 275., 1725., 850.)
|
||||
.with_xyxy(425., 600., 700., 875.)
|
||||
.with_xyxy(1375., 550., 1650., 800.)
|
||||
.with_xyxy(1240., 675., 1400., 750.)];
|
||||
|
||||
// Run & Annotate
|
||||
let ys = model.forward(&xs, &prompts)?;
|
||||
|
||||
// annotate
|
||||
let annotator = Annotator::default()
|
||||
.with_mask_style(usls::Style::mask().with_draw_mask_polygon_largest(true));
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
annotator.annotate(x, y)?.save(format!(
|
||||
"{}.jpg",
|
||||
usls::Dir::Current
|
||||
.base_dir_with_subs(&["runs", model.spec()])?
|
||||
.join(usls::timestamp(None))
|
||||
.display(),
|
||||
))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use usls::{
|
||||
models::{SamPrompt, SAM, YOLO},
|
||||
Annotator, DataLoader, Options, Scale, Style,
|
||||
};
|
||||
|
||||
#[derive(argh::FromArgs)]
|
||||
/// Example
|
||||
struct Args {
|
||||
/// device
|
||||
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||
device: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
||||
.init();
|
||||
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// build SAM
|
||||
let (options_encoder, options_decoder) = (
|
||||
Options::mobile_sam_tiny_encoder().commit()?,
|
||||
Options::mobile_sam_tiny_decoder().commit()?,
|
||||
);
|
||||
let mut sam = SAM::new(options_encoder, options_decoder)?;
|
||||
|
||||
// build YOLOv8
|
||||
let options_yolo = Options::yolo_detect()
|
||||
.with_model_scale(Scale::N)
|
||||
.with_model_version(8.into())
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.commit()?;
|
||||
let mut yolo = YOLO::new(options_yolo)?;
|
||||
|
||||
// load one image
|
||||
let xs = DataLoader::try_read_n(&["images/dog.jpg"])?;
|
||||
|
||||
// build annotator
|
||||
let annotator = Annotator::default().with_hbb_style(Style::hbb().with_draw_fill(true));
|
||||
|
||||
// run & annotate
|
||||
let ys_det = yolo.forward(&xs)?;
|
||||
for y_det in ys_det.iter() {
|
||||
if let Some(hbbs) = y_det.hbbs() {
|
||||
for hbb in hbbs {
|
||||
let ys_sam = sam.forward(
|
||||
&xs,
|
||||
&[SamPrompt::default().with_bbox(
|
||||
hbb.xmin(),
|
||||
hbb.ymin(),
|
||||
hbb.xmax(),
|
||||
hbb.ymax(),
|
||||
)],
|
||||
)?;
|
||||
// annotator.annotate(&xs, &ys_sam);
|
||||
for (x, y) in xs.iter().zip(ys_sam.iter()) {
|
||||
annotator.annotate(x, y)?.save(format!(
|
||||
"{}.jpg",
|
||||
usls::Dir::Current
|
||||
.base_dir_with_subs(&["runs", "YOLO-SAM"])?
|
||||
.join(usls::timestamp(None))
|
||||
.display(),
|
||||
))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
79
examples/yolo-sam2/main.rs
Normal file
79
examples/yolo-sam2/main.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use anyhow::Result;
|
||||
use usls::{
|
||||
models::{SamPrompt, SAM2, YOLO},
|
||||
Annotator, DataLoader, Options, Scale, Style,
|
||||
};
|
||||
|
||||
#[derive(argh::FromArgs)]
|
||||
/// Example
|
||||
struct Args {
|
||||
/// device
|
||||
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||
device: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
||||
.init();
|
||||
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// build SAM
|
||||
let (options_encoder, options_decoder) = (
|
||||
Options::sam2_1_tiny_encoder().commit()?,
|
||||
Options::sam2_1_tiny_decoder().commit()?,
|
||||
);
|
||||
let mut sam = SAM2::new(options_encoder, options_decoder)?;
|
||||
|
||||
// build YOLOv8
|
||||
let options_yolo = Options::yolo_detect()
|
||||
.with_model_scale(Scale::N)
|
||||
.with_model_version(8.into())
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.commit()?;
|
||||
let mut yolo = YOLO::new(options_yolo)?;
|
||||
|
||||
// load one image
|
||||
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
|
||||
|
||||
// build annotator
|
||||
let annotator = Annotator::default()
|
||||
.with_polygon_style(
|
||||
Style::polygon()
|
||||
.with_visible(true)
|
||||
.with_text_visible(true)
|
||||
.show_id(true)
|
||||
.show_name(true),
|
||||
)
|
||||
.with_mask_style(Style::mask().with_draw_mask_polygon_largest(true));
|
||||
|
||||
// run & annotate
|
||||
let ys_det = yolo.forward(&xs)?;
|
||||
for y_det in ys_det.iter() {
|
||||
if let Some(hbbs) = y_det.hbbs() {
|
||||
// collect hhbs
|
||||
let mut prompt = SamPrompt::default();
|
||||
for hbb in hbbs {
|
||||
prompt = prompt.with_xyxy(hbb.xmin(), hbb.ymin(), hbb.xmax(), hbb.ymax());
|
||||
}
|
||||
|
||||
// sam2 infer
|
||||
let ys_sam = sam.forward(&xs, &[prompt])?;
|
||||
|
||||
// annotate
|
||||
for (x, y) in xs.iter().zip(ys_sam.iter()) {
|
||||
annotator.annotate(x, y)?.save(format!(
|
||||
"{}.jpg",
|
||||
usls::Dir::Current
|
||||
.base_dir_with_subs(&["runs", "YOLO-SAM2"])?
|
||||
.join(usls::timestamp(None))
|
||||
.display(),
|
||||
))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user