Add SAM2.1 models and support batched prompt inputs (#89)

Authored by Jamjamjon on 2025-05-13 17:28:25 +08:00, committed by GitHub
parent c54775cedd
commit 675dd63734
14 changed files with 552 additions and 141 deletions

@@ -16,7 +16,7 @@ struct Args {
scale: String,
/// SAM kind
- #[argh(option, default = "String::from(\"sam\")")]
+ #[argh(option, default = "String::from(\"samhq\")")]
kind: String,
}
@@ -69,9 +69,19 @@ fn main() -> Result<()> {
// Prompt
let prompts = vec![
SamPrompt::default()
- // .with_postive_point(500., 375.), // postive point
- // .with_negative_point(774., 366.), // negative point
- .with_bbox(215., 297., 643., 459.), // bbox
+ // // # demo: point + point
+ // .with_positive_point(500., 375.) // mid window
+ // .with_positive_point(1125., 625.), // car door
+ // // # demo: bbox
+ // .with_xyxy(425., 600., 700., 875.), // left wheel
+ // // Note: When specifying multiple boxes for multiple objects, only the last box is supported; all previous boxes will be ignored.
+ // .with_xyxy(75., 275., 1725., 850.)
+ // .with_xyxy(425., 600., 700., 875.)
+ // .with_xyxy(1240., 675., 1400., 750.)
+ // .with_xyxy(1375., 550., 1650., 800.)
+ // # demo: bbox + negative point
+ .with_xyxy(425., 600., 700., 875.) // left wheel
+ .with_negative_point(575., 750.), // tire
];
// Run & Annotate

examples/sam2/README.md (new file, +6 lines)

@@ -0,0 +1,6 @@
## Quick Start
```Shell
cargo run -r -F cuda --example sam2 -- --device cuda --scale t
```

examples/sam2/main.rs (new file, +93 lines)

@@ -0,0 +1,93 @@
use anyhow::Result;
use usls::{
models::{SamPrompt, SAM2},
Annotator, DataLoader, Options, Scale,
};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
/// scale
#[argh(option, default = "String::from(\"t\")")]
scale: String,
}
fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
let args: Args = argh::from_env();
// Build model
let (options_encoder, options_decoder) = match args.scale.as_str().try_into()? {
Scale::T => (
Options::sam2_1_tiny_encoder(),
Options::sam2_1_tiny_decoder(),
),
Scale::S => (
Options::sam2_1_small_encoder(),
Options::sam2_1_small_decoder(),
),
Scale::B => (
Options::sam2_1_base_plus_encoder(),
Options::sam2_1_base_plus_decoder(),
),
Scale::L => (
Options::sam2_1_large_encoder(),
Options::sam2_1_large_decoder(),
),
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale),
};
let options_encoder = options_encoder
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let options_decoder = options_decoder
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let mut model = SAM2::new(options_encoder, options_decoder)?;
// Load image
let xs = DataLoader::try_read_n(&["images/truck.jpg"])?;
// Prompt
let prompts = vec![SamPrompt::default()
// // # demo: point + point
// .with_positive_point(500., 375.) // mid window
// .with_positive_point(1125., 625.), // car door
// // # demo: bbox
// .with_xyxy(425., 600., 700., 875.), // left wheel
// // # demo: bbox + negative point
// .with_xyxy(425., 600., 700., 875.) // left wheel
// .with_negative_point(575., 750.), // tire
// # demo: multiple objects with boxes
.with_xyxy(75., 275., 1725., 850.)
.with_xyxy(425., 600., 700., 875.)
.with_xyxy(1375., 550., 1650., 800.)
.with_xyxy(1240., 675., 1400., 750.)];
// Run & Annotate
let ys = model.forward(&xs, &prompts)?;
// annotate
let annotator = Annotator::default()
.with_mask_style(usls::Style::mask().with_draw_mask_polygon_largest(true));
for (x, y) in xs.iter().zip(ys.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",
usls::Dir::Current
.base_dir_with_subs(&["runs", model.spec()])?
.join(usls::timestamp(None))
.display(),
))?;
}
Ok(())
}

@@ -1,73 +0,0 @@
use anyhow::Result;
use usls::{
models::{SamPrompt, SAM, YOLO},
Annotator, DataLoader, Options, Scale, Style,
};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
}
fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
let args: Args = argh::from_env();
// build SAM
let (options_encoder, options_decoder) = (
Options::mobile_sam_tiny_encoder().commit()?,
Options::mobile_sam_tiny_decoder().commit()?,
);
let mut sam = SAM::new(options_encoder, options_decoder)?;
// build YOLOv8
let options_yolo = Options::yolo_detect()
.with_model_scale(Scale::N)
.with_model_version(8.into())
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let mut yolo = YOLO::new(options_yolo)?;
// load one image
let xs = DataLoader::try_read_n(&["images/dog.jpg"])?;
// build annotator
let annotator = Annotator::default().with_hbb_style(Style::hbb().with_draw_fill(true));
// run & annotate
let ys_det = yolo.forward(&xs)?;
for y_det in ys_det.iter() {
if let Some(hbbs) = y_det.hbbs() {
for hbb in hbbs {
let ys_sam = sam.forward(
&xs,
&[SamPrompt::default().with_bbox(
hbb.xmin(),
hbb.ymin(),
hbb.xmax(),
hbb.ymax(),
)],
)?;
// annotator.annotate(&xs, &ys_sam);
for (x, y) in xs.iter().zip(ys_sam.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",
usls::Dir::Current
.base_dir_with_subs(&["runs", "YOLO-SAM"])?
.join(usls::timestamp(None))
.display(),
))?;
}
}
}
}
Ok(())
}

@@ -0,0 +1,79 @@
use anyhow::Result;
use usls::{
models::{SamPrompt, SAM2, YOLO},
Annotator, DataLoader, Options, Scale, Style,
};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
}
fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
let args: Args = argh::from_env();
// build SAM2
let (options_encoder, options_decoder) = (
Options::sam2_1_tiny_encoder().commit()?,
Options::sam2_1_tiny_decoder().commit()?,
);
let mut sam = SAM2::new(options_encoder, options_decoder)?;
// build YOLOv8
let options_yolo = Options::yolo_detect()
.with_model_scale(Scale::N)
.with_model_version(8.into())
.with_model_device(args.device.as_str().try_into()?)
.commit()?;
let mut yolo = YOLO::new(options_yolo)?;
// load one image
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
// build annotator
let annotator = Annotator::default()
.with_polygon_style(
Style::polygon()
.with_visible(true)
.with_text_visible(true)
.show_id(true)
.show_name(true),
)
.with_mask_style(Style::mask().with_draw_mask_polygon_largest(true));
// run & annotate
let ys_det = yolo.forward(&xs)?;
for y_det in ys_det.iter() {
if let Some(hbbs) = y_det.hbbs() {
// collect hbbs
let mut prompt = SamPrompt::default();
for hbb in hbbs {
prompt = prompt.with_xyxy(hbb.xmin(), hbb.ymin(), hbb.xmax(), hbb.ymax());
}
// sam2 infer
let ys_sam = sam.forward(&xs, &[prompt])?;
// annotate
for (x, y) in xs.iter().zip(ys_sam.iter()) {
annotator.annotate(x, y)?.save(format!(
"{}.jpg",
usls::Dir::Current
.base_dir_with_subs(&["runs", "YOLO-SAM2"])?
.join(usls::timestamp(None))
.display(),
))?;
}
}
}
Ok(())
}

@@ -24,6 +24,7 @@ mod rfdetr;
mod rtdetr;
mod rtmo;
mod sam;
mod sam2;
mod sapiens;
mod slanet;
mod smolvlm;
@@ -49,6 +50,7 @@ pub use rfdetr::*;
pub use rtdetr::*;
pub use rtmo::*;
pub use sam::*;
pub use sam2::*;
pub use sapiens::*;
pub use slanet::*;
pub use smolvlm::*;

@@ -1,16 +1,17 @@
use aksr::Builder;
use anyhow::Result;
- use ndarray::{s, Array, Axis};
+ use ndarray::{s, Axis};
use rand::prelude::*;
use crate::{
- elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, Ts, Xs, X, Y,
+ elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, SamPrompt, Ts, Xs, X,
+ Y,
};
#[derive(Debug, Clone)]
pub enum SamKind {
Sam,
- Sam2,
+ Sam2, // 2.0
MobileSam,
SamHq,
EdgeSam,
@@ -31,54 +32,6 @@ impl TryFrom<&str> for SamKind {
}
}
#[derive(Debug, Default, Clone)]
pub struct SamPrompt {
points: Vec<f32>,
labels: Vec<f32>,
}
impl SamPrompt {
pub fn everything() -> Self {
todo!()
}
pub fn with_postive_point(mut self, x: f32, y: f32) -> Self {
self.points.extend_from_slice(&[x, y]);
self.labels.push(1.);
self
}
pub fn with_negative_point(mut self, x: f32, y: f32) -> Self {
self.points.extend_from_slice(&[x, y]);
self.labels.push(0.);
self
}
pub fn with_bbox(mut self, x: f32, y: f32, x2: f32, y2: f32) -> Self {
self.points.extend_from_slice(&[x, y, x2, y2]);
self.labels.extend_from_slice(&[2., 3.]);
self
}
pub fn point_coords(&self, r: f32) -> Result<X> {
let point_coords = Array::from_shape_vec((1, self.num_points(), 2), self.points.clone())?
.into_dyn()
.into_owned();
Ok(X::from(point_coords * r))
}
pub fn point_labels(&self) -> Result<X> {
let point_labels = Array::from_shape_vec((1, self.num_points()), self.labels.clone())?
.into_dyn()
.into_owned();
Ok(X::from(point_labels))
}
pub fn num_points(&self) -> usize {
self.points.len() / 2
}
}
#[derive(Builder, Debug)]
pub struct SAM {
encoder: Engine,
@@ -167,14 +120,28 @@ impl SAM {
);
let ratio = self.processor.images_transform_info[idx].height_scale;
+ let (mut point_coords, mut point_labels) = (
+ prompts[idx].point_coords(ratio)?,
+ prompts[idx].point_labels()?,
+ );
+ if point_coords.shape()[0] != 1 {
+ point_coords = X::from(point_coords.slice(s![-1, .., ..]).to_owned().into_dyn())
+ .insert_axis(0)?;
+ }
+ if point_labels.shape()[0] != 1 {
+ point_labels = X::from(point_labels.slice(s![-1, ..,]).to_owned().into_dyn())
+ .insert_axis(0)?;
+ }
let args = match self.kind {
SamKind::Sam | SamKind::MobileSam => {
vec![
X::from(image_embedding.into_dyn().into_owned())
.insert_axis(0)?
.repeat(0, self.batch)?, // image_embedding
- prompts[idx].point_coords(ratio)?, // point_coords
- prompts[idx].point_labels()?, // point_labels
+ point_coords,
+ point_labels,
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input,
X::zeros(&[1]), // has_mask_input
X::from(vec![image_height as _, image_width as _]), // orig_im_size
@@ -189,8 +156,8 @@
.insert_axis(0)?
.insert_axis(0)?
.repeat(0, self.batch)?, // intern_embedding
- prompts[idx].point_coords(ratio)?, // point_coords
- prompts[idx].point_labels()?, // point_labels
+ point_coords,
+ point_labels,
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
X::zeros(&[1]), // has_mask_input
X::from(vec![image_height as _, image_width as _]), // orig_im_size
@@ -201,8 +168,8 @@
X::from(image_embedding.into_dyn().into_owned())
.insert_axis(0)?
.repeat(0, self.batch)?,
- prompts[idx].point_coords(ratio)?,
- prompts[idx].point_labels()?,
+ point_coords,
+ point_labels,
]
}
SamKind::Sam2 => {
@@ -228,11 +195,11 @@
)
.insert_axis(0)?
.repeat(0, self.batch)?,
- prompts[idx].point_coords(ratio)?,
- prompts[idx].point_labels()?,
- X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
- X::zeros(&[1]), // has_mask_input
- X::from(vec![image_height as _, image_width as _]), // orig_im_size
+ point_coords,
+ point_labels,
+ X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]),
+ X::zeros(&[1]),
+ X::from(vec![image_height as _, image_width as _]),
]
}
};

@@ -2,3 +2,88 @@ mod config;
mod r#impl;
pub use r#impl::*;
#[derive(Debug, Default, Clone)]
pub struct SamPrompt {
pub coords: Vec<Vec<[f32; 2]>>,
pub labels: Vec<Vec<f32>>,
}
impl SamPrompt {
pub fn point_coords(&self, ratio: f32) -> anyhow::Result<crate::X> {
// [num_labels,num_points,2]
let num_labels = self.coords.len();
let num_points = if num_labels > 0 {
self.coords[0].len()
} else {
0
};
let flat: Vec<f32> = self
.coords
.iter()
.flat_map(|v| v.iter().flat_map(|&[x, y]| [x, y]))
.collect();
let y = ndarray::Array3::from_shape_vec((num_labels, num_points, 2), flat)?.into_dyn();
Ok((y * ratio).into())
}
pub fn point_labels(&self) -> anyhow::Result<crate::X> {
// [num_labels,num_points]
let num_labels = self.labels.len();
let num_points = if num_labels > 0 {
self.labels[0].len()
} else {
0
};
let flat: Vec<f32> = self.labels.iter().flat_map(|v| v.iter().copied()).collect();
let y = ndarray::Array2::from_shape_vec((num_labels, num_points), flat)?.into_dyn();
Ok(y.into())
}
pub fn with_xyxy(mut self, x1: f32, y1: f32, x2: f32, y2: f32) -> Self {
// TODO: if already has points, push_front coords
self.coords.push(vec![[x1, y1], [x2, y2]]);
self.labels.push(vec![2., 3.]);
self
}
pub fn with_positive_point(mut self, x: f32, y: f32) -> Self {
self = self.add_point(x, y, 1.);
self
}
pub fn with_negative_point(mut self, x: f32, y: f32) -> Self {
self = self.add_point(x, y, 0.);
self
}
fn add_point(mut self, x: f32, y: f32, id: f32) -> Self {
if self.coords.is_empty() {
self.coords.push(vec![[x, y]]);
self.labels.push(vec![id]);
} else {
if let Some(last) = self.coords.last_mut() {
last.extend_from_slice(&[[x, y]]);
}
if let Some(last) = self.labels.last_mut() {
last.extend_from_slice(&[id]);
}
}
self
}
pub fn with_positive_point_object(mut self, x: f32, y: f32) -> Self {
self.coords.push(vec![[x, y]]);
self.labels.push(vec![1.]);
self
}
pub fn with_negative_point_object(mut self, x: f32, y: f32) -> Self {
self.coords.push(vec![[x, y]]);
self.labels.push(vec![0.]);
self
}
}
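For reference, the layout produced by the builders above: each `with_xyxy` call starts a new label group holding the two box corners with labels `2.` and `3.`, `with_positive_point`/`with_negative_point` append to the last group, and the `*_object` variants start a fresh group, so `point_coords()` flattens everything to `[num_labels, num_points, 2]` with `num_points` taken from the first group. A minimal standalone sketch of that layout (plain vectors rather than the `usls` types, coordinate values illustrative):

```Rust
fn main() {
    // Mirrors SamPrompt::coords / SamPrompt::labels for two box prompts.
    let mut coords: Vec<Vec<[f32; 2]>> = Vec::new();
    let mut labels: Vec<Vec<f32>> = Vec::new();

    // with_xyxy(75., 275., 1725., 850.): one group of two corner points, labels [2., 3.].
    coords.push(vec![[75., 275.], [1725., 850.]]);
    labels.push(vec![2., 3.]);

    // with_xyxy(425., 600., 700., 875.): a second group, i.e. a second object in the same pass.
    coords.push(vec![[425., 600.], [700., 875.]]);
    labels.push(vec![2., 3.]);

    // point_coords() reshapes the flattened values to [num_labels, num_points, 2],
    // with num_points read from the first group, so groups should be equally sized.
    let (num_labels, num_points) = (coords.len(), coords[0].len());
    assert_eq!((num_labels, num_points), (2, 2)); // final coord shape: [2, 2, 2]

    // point_labels() likewise yields shape [num_labels, num_points] = [2, 2].
    assert_eq!((labels.len(), labels[0].len()), (2, 2));
}
```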

src/models/sam2/README.md (new file, +10 lines)

@@ -0,0 +1,10 @@
# Segment Anything Model 2 (SAM2)
## Official Repository
The official repository can be found at [sam2](https://github.com/facebookresearch/sam2).
## Example
Refer to the [example](../../../examples/sam2).
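A condensed usage sketch, distilled from `examples/sam2/main.rs` in this commit (the image path, prompt boxes, and output file name are illustrative, and it assumes the `usls` crate with the `SAM2` model and builders added here):

```Rust
use anyhow::Result;
use usls::{
    models::{SamPrompt, SAM2},
    Annotator, DataLoader, Options,
};

fn main() -> Result<()> {
    // SAM2.1 tiny encoder/decoder pair with the default (CPU) device.
    let mut model = SAM2::new(
        Options::sam2_1_tiny_encoder().commit()?,
        Options::sam2_1_tiny_decoder().commit()?,
    )?;

    // One image, one prompt carrying two boxes: both objects are segmented in a single pass.
    let xs = DataLoader::try_read_n(&["images/truck.jpg"])?;
    let prompts = vec![SamPrompt::default()
        .with_xyxy(75., 275., 1725., 850.)
        .with_xyxy(425., 600., 700., 875.)];

    let ys = model.forward(&xs, &prompts)?;

    // Draw the predicted masks and save one annotated image per input.
    let annotator = Annotator::default()
        .with_mask_style(usls::Style::mask().with_draw_mask_polygon_largest(true));
    for (i, (x, y)) in xs.iter().zip(ys.iter()).enumerate() {
        annotator.annotate(x, y)?.save(format!("sam2-demo-{i}.jpg"))?;
    }
    Ok(())
}
```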

src/models/sam2/config.rs (new file, +50 lines)

@@ -0,0 +1,50 @@
use crate::Options;
/// Model configuration for `SAM2.1`
impl Options {
pub fn sam2_encoder() -> Self {
Self::sam()
.with_model_ixx(0, 2, 1024.into())
.with_model_ixx(0, 3, 1024.into())
.with_resize_mode(crate::ResizeMode::FitAdaptive)
.with_resize_filter("Bilinear")
.with_image_mean(&[0.485, 0.456, 0.406])
.with_image_std(&[0.229, 0.224, 0.225])
}
pub fn sam2_decoder() -> Self {
Self::sam()
}
pub fn sam2_1_tiny_encoder() -> Self {
Self::sam2_encoder().with_model_file("sam2.1-hiera-tiny-encoder.onnx")
}
pub fn sam2_1_tiny_decoder() -> Self {
Self::sam2_decoder().with_model_file("sam2.1-hiera-tiny-decoder.onnx")
}
pub fn sam2_1_small_encoder() -> Self {
Self::sam2_encoder().with_model_file("sam2.1-hiera-small-encoder.onnx")
}
pub fn sam2_1_small_decoder() -> Self {
Self::sam2_decoder().with_model_file("sam2.1-hiera-small-decoder.onnx")
}
pub fn sam2_1_base_plus_encoder() -> Self {
Self::sam2_encoder().with_model_file("sam2.1-hiera-base-plus-encoder.onnx")
}
pub fn sam2_1_base_plus_decoder() -> Self {
Self::sam2_decoder().with_model_file("sam2.1-hiera-base-plus-decoder.onnx")
}
pub fn sam2_1_large_encoder() -> Self {
Self::sam2_encoder().with_model_file("sam2.1-hiera-large-encoder.onnx")
}
pub fn sam2_1_large_decoder() -> Self {
Self::sam2_decoder().with_model_file("sam2.1-hiera-large-decoder.onnx")
}
}
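The per-scale presets above differ only in the `with_model_file` name, so the shared `sam2_encoder()`/`sam2_decoder()` defaults (1024x1024 input, adaptive-fit resize, ImageNet normalization) could be reused for a custom export. A sketch under that assumption; the file names are hypothetical, and whether `with_model_file` resolves arbitrary local files is not shown in this diff:

```Rust
use usls::Options;

// Hypothetical: reuse the SAM2 preprocessing presets with custom ONNX exports.
fn custom_sam2_options() -> anyhow::Result<(Options, Options)> {
    let encoder = Options::sam2_encoder()
        .with_model_file("sam2.1-hiera-tiny-encoder-int8.onnx") // hypothetical file name
        .commit()?;
    let decoder = Options::sam2_decoder()
        .with_model_file("sam2.1-hiera-tiny-decoder-int8.onnx") // hypothetical file name
        .commit()?;
    Ok((encoder, decoder))
}

fn main() -> anyhow::Result<()> {
    let (_encoder, _decoder) = custom_sam2_options()?;
    Ok(())
}
```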

src/models/sam2/impl.rs (new file, +164 lines)

@@ -0,0 +1,164 @@
use aksr::Builder;
use anyhow::Result;
use ndarray::{s, Axis};
use crate::{
elapsed, DynConf, Engine, Image, Mask, Ops, Options, Processor, SamPrompt, Ts, Xs, X, Y,
};
#[derive(Builder, Debug)]
pub struct SAM2 {
encoder: Engine,
decoder: Engine,
height: usize,
width: usize,
batch: usize,
processor: Processor,
conf: DynConf,
ts: Ts,
spec: String,
}
impl SAM2 {
pub fn new(options_encoder: Options, options_decoder: Options) -> Result<Self> {
let encoder = options_encoder.to_engine()?;
let decoder = options_decoder.to_engine()?;
let (batch, height, width) = (
encoder.batch().opt(),
encoder.try_height().unwrap_or(&1024.into()).opt(),
encoder.try_width().unwrap_or(&1024.into()).opt(),
);
let ts = Ts::merge(&[encoder.ts(), decoder.ts()]);
let spec = encoder.spec().to_owned();
let processor = options_encoder
.to_processor()?
.with_image_width(width as _)
.with_image_height(height as _);
let conf = DynConf::new(options_encoder.class_confs(), 1);
Ok(Self {
encoder,
decoder,
conf,
batch,
height,
width,
ts,
processor,
spec,
})
}
pub fn forward(&mut self, xs: &[Image], prompts: &[SamPrompt]) -> Result<Vec<Y>> {
let ys = elapsed!("encode", self.ts, { self.encode(xs)? });
let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? });
Ok(ys)
}
pub fn encode(&mut self, xs: &[Image]) -> Result<Xs> {
let xs_ = self.processor.process_images(xs)?;
self.encoder.run(Xs::from(xs_))
}
pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result<Vec<Y>> {
let (image_embeddings, high_res_features_0, high_res_features_1) = (&xs[0], &xs[1], &xs[2]);
let mut ys: Vec<Y> = Vec::new();
for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() {
let (image_height, image_width) = (
self.processor.images_transform_info[idx].height_src,
self.processor.images_transform_info[idx].width_src,
);
let ratio = self.processor.images_transform_info[idx].height_scale;
let ys_ = self.decoder.run(Xs::from(vec![
X::from(image_embedding.into_dyn().into_owned())
.insert_axis(0)?
.repeat(0, self.batch)?,
X::from(
high_res_features_0
.slice(s![idx, .., .., ..])
.into_dyn()
.into_owned(),
)
.insert_axis(0)?
.repeat(0, self.batch)?,
X::from(
high_res_features_1
.slice(s![idx, .., .., ..])
.into_dyn()
.into_owned(),
)
.insert_axis(0)?
.repeat(0, self.batch)?,
prompts[idx].point_coords(ratio)?,
prompts[idx].point_labels()?,
// TODO
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]),
X::zeros(&[1]),
X::from(vec![self.width as _, self.height as _]),
]))?;
let mut y_masks: Vec<Mask> = Vec::new();
// masks & confs
let (masks, confs) = (&ys_[0], &ys_[1]);
for (id, (mask, iou)) in masks
.axis_iter(Axis(0))
.zip(confs.axis_iter(Axis(0)))
.enumerate()
{
let (i, conf) = match iou
.to_owned()
.into_raw_vec_and_offset()
.0
.into_iter()
.enumerate()
.max_by(|a, b| a.1.total_cmp(&b.1))
{
Some((i, c)) => (i, c),
None => continue,
};
if conf < self.conf[0] {
continue;
}
let mask = mask.slice(s![i, .., ..]);
let (h, w) = mask.dim();
let luma = Ops::resize_lumaf32_u8(
&mask.into_owned().into_raw_vec_and_offset().0,
w as _,
h as _,
image_width as _,
image_height as _,
true,
"Bilinear",
)?;
// contours
let mask = Mask::new(&luma, image_width, image_height)?.with_id(id);
y_masks.push(mask);
}
let mut y = Y::default();
if !y_masks.is_empty() {
y = y.with_masks(&y_masks);
}
ys.push(y);
}
Ok(ys)
}
pub fn width_low_res(&self) -> usize {
self.width / 4
}
pub fn height_low_res(&self) -> usize {
self.height / 4
}
}

src/models/sam2/mod.rs (new file, +4 lines)

@@ -0,0 +1,4 @@
mod config;
mod r#impl;
pub use r#impl::*;

@@ -27,6 +27,14 @@ pub struct Options {
pub trt_fp16: bool,
pub profile: bool,
// models
pub model_encoder_file: Option<String>,
pub model_decoder_file: Option<String>,
pub visual_encoder_file: Option<String>,
pub visual_decoder_file: Option<String>,
pub textual_encoder_file: Option<String>,
pub textual_decoder_file: Option<String>,
// Processor configs
#[args(except(setter))]
pub image_width: u32,
@@ -113,8 +121,8 @@ pub struct Options {
pub binary_thresh: Option<f32>,
// For SAM
- pub sam_kind: Option<SamKind>,
- pub low_res_mask: Option<bool>,
+ pub sam_kind: Option<SamKind>, // TODO: remove
+ pub low_res_mask: Option<bool>, // TODO: remove
// Others
pub ort_graph_opt_level: Option<u8>,
@@ -203,6 +211,12 @@ impl Default for Options {
topk_2: None,
topk_3: None,
ort_graph_opt_level: None,
model_encoder_file: None,
model_decoder_file: None,
visual_encoder_file: None,
visual_decoder_file: None,
textual_encoder_file: None,
textual_decoder_file: None,
}
}
}