mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Add SAM2.1 models and support batched prompt inputs (#89)
This commit is contained in:
@ -16,7 +16,7 @@ struct Args {
|
|||||||
scale: String,
|
scale: String,
|
||||||
|
|
||||||
/// SAM kind
|
/// SAM kind
|
||||||
#[argh(option, default = "String::from(\"sam\")")]
|
#[argh(option, default = "String::from(\"samhq\")")]
|
||||||
kind: String,
|
kind: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,9 +69,19 @@ fn main() -> Result<()> {
|
|||||||
// Prompt
|
// Prompt
|
||||||
let prompts = vec![
|
let prompts = vec![
|
||||||
SamPrompt::default()
|
SamPrompt::default()
|
||||||
// .with_postive_point(500., 375.), // postive point
|
// // # demo: point + point
|
||||||
// .with_negative_point(774., 366.), // negative point
|
// .with_positive_point(500., 375.) // mid window
|
||||||
.with_bbox(215., 297., 643., 459.), // bbox
|
// .with_positive_point(1125., 625.), // car door
|
||||||
|
// // # demo: bbox
|
||||||
|
// .with_xyxy(425., 600., 700., 875.), // left wheel
|
||||||
|
// // Note: When specifying multiple boxes for multiple objects, only the last box is supported; all previous boxes will be ignored.
|
||||||
|
// .with_xyxy(75., 275., 1725., 850.)
|
||||||
|
// .with_xyxy(425., 600., 700., 875.)
|
||||||
|
// .with_xyxy(1240., 675., 1400., 750.)
|
||||||
|
// .with_xyxy(1375., 550., 1650., 800.)
|
||||||
|
// # demo: bbox + negative point
|
||||||
|
.with_xyxy(425., 600., 700., 875.) // left wheel
|
||||||
|
.with_negative_point(575., 750.), // tire
|
||||||
];
|
];
|
||||||
|
|
||||||
// Run & Annotate
|
// Run & Annotate
|
||||||
|
6
examples/sam2/README.md
Normal file
6
examples/sam2/README.md
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```Shell
|
||||||
|
|
||||||
|
cargo run -r -F cuda --example sam -- --device cuda --scale t
|
||||||
|
```
|
93
examples/sam2/main.rs
Normal file
93
examples/sam2/main.rs
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use usls::{
|
||||||
|
models::{SamPrompt, SAM2},
|
||||||
|
Annotator, DataLoader, Options, Scale,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(argh::FromArgs)]
|
||||||
|
/// Example
|
||||||
|
struct Args {
|
||||||
|
/// device
|
||||||
|
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||||
|
device: String,
|
||||||
|
|
||||||
|
/// scale
|
||||||
|
#[argh(option, default = "String::from(\"t\")")]
|
||||||
|
scale: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let args: Args = argh::from_env();
|
||||||
|
|
||||||
|
// Build model
|
||||||
|
let (options_encoder, options_decoder) = match args.scale.as_str().try_into()? {
|
||||||
|
Scale::T => (
|
||||||
|
Options::sam2_1_tiny_encoder(),
|
||||||
|
Options::sam2_1_tiny_decoder(),
|
||||||
|
),
|
||||||
|
Scale::S => (
|
||||||
|
Options::sam2_1_small_encoder(),
|
||||||
|
Options::sam2_1_small_decoder(),
|
||||||
|
),
|
||||||
|
Scale::B => (
|
||||||
|
Options::sam2_1_base_plus_encoder(),
|
||||||
|
Options::sam2_1_base_plus_decoder(),
|
||||||
|
),
|
||||||
|
Scale::L => (
|
||||||
|
Options::sam2_1_large_encoder(),
|
||||||
|
Options::sam2_1_large_decoder(),
|
||||||
|
),
|
||||||
|
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale),
|
||||||
|
};
|
||||||
|
|
||||||
|
let options_encoder = options_encoder
|
||||||
|
.with_model_device(args.device.as_str().try_into()?)
|
||||||
|
.commit()?;
|
||||||
|
let options_decoder = options_decoder
|
||||||
|
.with_model_device(args.device.as_str().try_into()?)
|
||||||
|
.commit()?;
|
||||||
|
let mut model = SAM2::new(options_encoder, options_decoder)?;
|
||||||
|
|
||||||
|
// Load image
|
||||||
|
let xs = DataLoader::try_read_n(&["images/truck.jpg"])?;
|
||||||
|
|
||||||
|
// Prompt
|
||||||
|
let prompts = vec![SamPrompt::default()
|
||||||
|
// // # demo: point + point
|
||||||
|
// .with_positive_point(500., 375.) // mid window
|
||||||
|
// .with_positive_point(1125., 625.), // car door
|
||||||
|
// // # demo: bbox
|
||||||
|
// .with_xyxy(425., 600., 700., 875.), // left wheel
|
||||||
|
// // # demo: bbox + negative point
|
||||||
|
// .with_xyxy(425., 600., 700., 875.) // left wheel
|
||||||
|
// .with_negative_point(575., 750.), // tire
|
||||||
|
// # demo: multiple objects with boxes
|
||||||
|
.with_xyxy(75., 275., 1725., 850.)
|
||||||
|
.with_xyxy(425., 600., 700., 875.)
|
||||||
|
.with_xyxy(1375., 550., 1650., 800.)
|
||||||
|
.with_xyxy(1240., 675., 1400., 750.)];
|
||||||
|
|
||||||
|
// Run & Annotate
|
||||||
|
let ys = model.forward(&xs, &prompts)?;
|
||||||
|
|
||||||
|
// annotate
|
||||||
|
let annotator = Annotator::default()
|
||||||
|
.with_mask_style(usls::Style::mask().with_draw_mask_polygon_largest(true));
|
||||||
|
|
||||||
|
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||||
|
annotator.annotate(x, y)?.save(format!(
|
||||||
|
"{}.jpg",
|
||||||
|
usls::Dir::Current
|
||||||
|
.base_dir_with_subs(&["runs", model.spec()])?
|
||||||
|
.join(usls::timestamp(None))
|
||||||
|
.display(),
|
||||||
|
))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -1,73 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use usls::{
|
|
||||||
models::{SamPrompt, SAM, YOLO},
|
|
||||||
Annotator, DataLoader, Options, Scale, Style,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(argh::FromArgs)]
|
|
||||||
/// Example
|
|
||||||
struct Args {
|
|
||||||
/// device
|
|
||||||
#[argh(option, default = "String::from(\"cpu:0\")")]
|
|
||||||
device: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
tracing_subscriber::fmt()
|
|
||||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
|
||||||
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
|
||||||
.init();
|
|
||||||
|
|
||||||
let args: Args = argh::from_env();
|
|
||||||
|
|
||||||
// build SAM
|
|
||||||
let (options_encoder, options_decoder) = (
|
|
||||||
Options::mobile_sam_tiny_encoder().commit()?,
|
|
||||||
Options::mobile_sam_tiny_decoder().commit()?,
|
|
||||||
);
|
|
||||||
let mut sam = SAM::new(options_encoder, options_decoder)?;
|
|
||||||
|
|
||||||
// build YOLOv8
|
|
||||||
let options_yolo = Options::yolo_detect()
|
|
||||||
.with_model_scale(Scale::N)
|
|
||||||
.with_model_version(8.into())
|
|
||||||
.with_model_device(args.device.as_str().try_into()?)
|
|
||||||
.commit()?;
|
|
||||||
let mut yolo = YOLO::new(options_yolo)?;
|
|
||||||
|
|
||||||
// load one image
|
|
||||||
let xs = DataLoader::try_read_n(&["images/dog.jpg"])?;
|
|
||||||
|
|
||||||
// build annotator
|
|
||||||
let annotator = Annotator::default().with_hbb_style(Style::hbb().with_draw_fill(true));
|
|
||||||
|
|
||||||
// run & annotate
|
|
||||||
let ys_det = yolo.forward(&xs)?;
|
|
||||||
for y_det in ys_det.iter() {
|
|
||||||
if let Some(hbbs) = y_det.hbbs() {
|
|
||||||
for hbb in hbbs {
|
|
||||||
let ys_sam = sam.forward(
|
|
||||||
&xs,
|
|
||||||
&[SamPrompt::default().with_bbox(
|
|
||||||
hbb.xmin(),
|
|
||||||
hbb.ymin(),
|
|
||||||
hbb.xmax(),
|
|
||||||
hbb.ymax(),
|
|
||||||
)],
|
|
||||||
)?;
|
|
||||||
// annotator.annotate(&xs, &ys_sam);
|
|
||||||
for (x, y) in xs.iter().zip(ys_sam.iter()) {
|
|
||||||
annotator.annotate(x, y)?.save(format!(
|
|
||||||
"{}.jpg",
|
|
||||||
usls::Dir::Current
|
|
||||||
.base_dir_with_subs(&["runs", "YOLO-SAM"])?
|
|
||||||
.join(usls::timestamp(None))
|
|
||||||
.display(),
|
|
||||||
))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
79
examples/yolo-sam2/main.rs
Normal file
79
examples/yolo-sam2/main.rs
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use usls::{
|
||||||
|
models::{SamPrompt, SAM2, YOLO},
|
||||||
|
Annotator, DataLoader, Options, Scale, Style,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(argh::FromArgs)]
|
||||||
|
/// Example
|
||||||
|
struct Args {
|
||||||
|
/// device
|
||||||
|
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||||
|
device: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let args: Args = argh::from_env();
|
||||||
|
|
||||||
|
// build SAM
|
||||||
|
let (options_encoder, options_decoder) = (
|
||||||
|
Options::sam2_1_tiny_encoder().commit()?,
|
||||||
|
Options::sam2_1_tiny_decoder().commit()?,
|
||||||
|
);
|
||||||
|
let mut sam = SAM2::new(options_encoder, options_decoder)?;
|
||||||
|
|
||||||
|
// build YOLOv8
|
||||||
|
let options_yolo = Options::yolo_detect()
|
||||||
|
.with_model_scale(Scale::N)
|
||||||
|
.with_model_version(8.into())
|
||||||
|
.with_model_device(args.device.as_str().try_into()?)
|
||||||
|
.commit()?;
|
||||||
|
let mut yolo = YOLO::new(options_yolo)?;
|
||||||
|
|
||||||
|
// load one image
|
||||||
|
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
|
||||||
|
|
||||||
|
// build annotator
|
||||||
|
let annotator = Annotator::default()
|
||||||
|
.with_polygon_style(
|
||||||
|
Style::polygon()
|
||||||
|
.with_visible(true)
|
||||||
|
.with_text_visible(true)
|
||||||
|
.show_id(true)
|
||||||
|
.show_name(true),
|
||||||
|
)
|
||||||
|
.with_mask_style(Style::mask().with_draw_mask_polygon_largest(true));
|
||||||
|
|
||||||
|
// run & annotate
|
||||||
|
let ys_det = yolo.forward(&xs)?;
|
||||||
|
for y_det in ys_det.iter() {
|
||||||
|
if let Some(hbbs) = y_det.hbbs() {
|
||||||
|
// collect hhbs
|
||||||
|
let mut prompt = SamPrompt::default();
|
||||||
|
for hbb in hbbs {
|
||||||
|
prompt = prompt.with_xyxy(hbb.xmin(), hbb.ymin(), hbb.xmax(), hbb.ymax());
|
||||||
|
}
|
||||||
|
|
||||||
|
// sam2 infer
|
||||||
|
let ys_sam = sam.forward(&xs, &[prompt])?;
|
||||||
|
|
||||||
|
// annotate
|
||||||
|
for (x, y) in xs.iter().zip(ys_sam.iter()) {
|
||||||
|
annotator.annotate(x, y)?.save(format!(
|
||||||
|
"{}.jpg",
|
||||||
|
usls::Dir::Current
|
||||||
|
.base_dir_with_subs(&["runs", "YOLO-SAM2"])?
|
||||||
|
.join(usls::timestamp(None))
|
||||||
|
.display(),
|
||||||
|
))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -24,6 +24,7 @@ mod rfdetr;
|
|||||||
mod rtdetr;
|
mod rtdetr;
|
||||||
mod rtmo;
|
mod rtmo;
|
||||||
mod sam;
|
mod sam;
|
||||||
|
mod sam2;
|
||||||
mod sapiens;
|
mod sapiens;
|
||||||
mod slanet;
|
mod slanet;
|
||||||
mod smolvlm;
|
mod smolvlm;
|
||||||
@ -49,6 +50,7 @@ pub use rfdetr::*;
|
|||||||
pub use rtdetr::*;
|
pub use rtdetr::*;
|
||||||
pub use rtmo::*;
|
pub use rtmo::*;
|
||||||
pub use sam::*;
|
pub use sam::*;
|
||||||
|
pub use sam2::*;
|
||||||
pub use sapiens::*;
|
pub use sapiens::*;
|
||||||
pub use slanet::*;
|
pub use slanet::*;
|
||||||
pub use smolvlm::*;
|
pub use smolvlm::*;
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
use aksr::Builder;
|
use aksr::Builder;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use ndarray::{s, Array, Axis};
|
use ndarray::{s, Axis};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, Ts, Xs, X, Y,
|
elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, SamPrompt, Ts, Xs, X,
|
||||||
|
Y,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum SamKind {
|
pub enum SamKind {
|
||||||
Sam,
|
Sam,
|
||||||
Sam2,
|
Sam2, // 2.0
|
||||||
MobileSam,
|
MobileSam,
|
||||||
SamHq,
|
SamHq,
|
||||||
EdgeSam,
|
EdgeSam,
|
||||||
@ -31,54 +32,6 @@ impl TryFrom<&str> for SamKind {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone)]
|
|
||||||
pub struct SamPrompt {
|
|
||||||
points: Vec<f32>,
|
|
||||||
labels: Vec<f32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SamPrompt {
|
|
||||||
pub fn everything() -> Self {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_postive_point(mut self, x: f32, y: f32) -> Self {
|
|
||||||
self.points.extend_from_slice(&[x, y]);
|
|
||||||
self.labels.push(1.);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_negative_point(mut self, x: f32, y: f32) -> Self {
|
|
||||||
self.points.extend_from_slice(&[x, y]);
|
|
||||||
self.labels.push(0.);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_bbox(mut self, x: f32, y: f32, x2: f32, y2: f32) -> Self {
|
|
||||||
self.points.extend_from_slice(&[x, y, x2, y2]);
|
|
||||||
self.labels.extend_from_slice(&[2., 3.]);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn point_coords(&self, r: f32) -> Result<X> {
|
|
||||||
let point_coords = Array::from_shape_vec((1, self.num_points(), 2), self.points.clone())?
|
|
||||||
.into_dyn()
|
|
||||||
.into_owned();
|
|
||||||
Ok(X::from(point_coords * r))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn point_labels(&self) -> Result<X> {
|
|
||||||
let point_labels = Array::from_shape_vec((1, self.num_points()), self.labels.clone())?
|
|
||||||
.into_dyn()
|
|
||||||
.into_owned();
|
|
||||||
Ok(X::from(point_labels))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn num_points(&self) -> usize {
|
|
||||||
self.points.len() / 2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Builder, Debug)]
|
#[derive(Builder, Debug)]
|
||||||
pub struct SAM {
|
pub struct SAM {
|
||||||
encoder: Engine,
|
encoder: Engine,
|
||||||
@ -167,14 +120,28 @@ impl SAM {
|
|||||||
);
|
);
|
||||||
let ratio = self.processor.images_transform_info[idx].height_scale;
|
let ratio = self.processor.images_transform_info[idx].height_scale;
|
||||||
|
|
||||||
|
let (mut point_coords, mut point_labels) = (
|
||||||
|
prompts[idx].point_coords(ratio)?,
|
||||||
|
prompts[idx].point_labels()?,
|
||||||
|
);
|
||||||
|
|
||||||
|
if point_coords.shape()[0] != 1 {
|
||||||
|
point_coords = X::from(point_coords.slice(s![-1, .., ..]).to_owned().into_dyn())
|
||||||
|
.insert_axis(0)?;
|
||||||
|
}
|
||||||
|
if point_labels.shape()[0] != 1 {
|
||||||
|
point_labels = X::from(point_labels.slice(s![-1, ..,]).to_owned().into_dyn())
|
||||||
|
.insert_axis(0)?;
|
||||||
|
}
|
||||||
|
|
||||||
let args = match self.kind {
|
let args = match self.kind {
|
||||||
SamKind::Sam | SamKind::MobileSam => {
|
SamKind::Sam | SamKind::MobileSam => {
|
||||||
vec![
|
vec![
|
||||||
X::from(image_embedding.into_dyn().into_owned())
|
X::from(image_embedding.into_dyn().into_owned())
|
||||||
.insert_axis(0)?
|
.insert_axis(0)?
|
||||||
.repeat(0, self.batch)?, // image_embedding
|
.repeat(0, self.batch)?, // image_embedding
|
||||||
prompts[idx].point_coords(ratio)?, // point_coords
|
point_coords,
|
||||||
prompts[idx].point_labels()?, // point_labels
|
point_labels,
|
||||||
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input,
|
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input,
|
||||||
X::zeros(&[1]), // has_mask_input
|
X::zeros(&[1]), // has_mask_input
|
||||||
X::from(vec![image_height as _, image_width as _]), // orig_im_size
|
X::from(vec![image_height as _, image_width as _]), // orig_im_size
|
||||||
@ -189,8 +156,8 @@ impl SAM {
|
|||||||
.insert_axis(0)?
|
.insert_axis(0)?
|
||||||
.insert_axis(0)?
|
.insert_axis(0)?
|
||||||
.repeat(0, self.batch)?, // intern_embedding
|
.repeat(0, self.batch)?, // intern_embedding
|
||||||
prompts[idx].point_coords(ratio)?, // point_coords
|
point_coords,
|
||||||
prompts[idx].point_labels()?, // point_labels
|
point_labels,
|
||||||
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
|
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
|
||||||
X::zeros(&[1]), // has_mask_input
|
X::zeros(&[1]), // has_mask_input
|
||||||
X::from(vec![image_height as _, image_width as _]), // orig_im_size
|
X::from(vec![image_height as _, image_width as _]), // orig_im_size
|
||||||
@ -201,8 +168,8 @@ impl SAM {
|
|||||||
X::from(image_embedding.into_dyn().into_owned())
|
X::from(image_embedding.into_dyn().into_owned())
|
||||||
.insert_axis(0)?
|
.insert_axis(0)?
|
||||||
.repeat(0, self.batch)?,
|
.repeat(0, self.batch)?,
|
||||||
prompts[idx].point_coords(ratio)?,
|
point_coords,
|
||||||
prompts[idx].point_labels()?,
|
point_labels,
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
SamKind::Sam2 => {
|
SamKind::Sam2 => {
|
||||||
@ -228,11 +195,11 @@ impl SAM {
|
|||||||
)
|
)
|
||||||
.insert_axis(0)?
|
.insert_axis(0)?
|
||||||
.repeat(0, self.batch)?,
|
.repeat(0, self.batch)?,
|
||||||
prompts[idx].point_coords(ratio)?,
|
point_coords,
|
||||||
prompts[idx].point_labels()?,
|
point_labels,
|
||||||
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
|
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]),
|
||||||
X::zeros(&[1]), // has_mask_input
|
X::zeros(&[1]),
|
||||||
X::from(vec![image_height as _, image_width as _]), // orig_im_size
|
X::from(vec![image_height as _, image_width as _]),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -2,3 +2,88 @@ mod config;
|
|||||||
mod r#impl;
|
mod r#impl;
|
||||||
|
|
||||||
pub use r#impl::*;
|
pub use r#impl::*;
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
|
pub struct SamPrompt {
|
||||||
|
pub coords: Vec<Vec<[f32; 2]>>,
|
||||||
|
pub labels: Vec<Vec<f32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SamPrompt {
|
||||||
|
pub fn point_coords(&self, ratio: f32) -> anyhow::Result<crate::X> {
|
||||||
|
// [num_labels,num_points,2]
|
||||||
|
let num_labels = self.coords.len();
|
||||||
|
let num_points = if num_labels > 0 {
|
||||||
|
self.coords[0].len()
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let flat: Vec<f32> = self
|
||||||
|
.coords
|
||||||
|
.iter()
|
||||||
|
.flat_map(|v| v.iter().flat_map(|&[x, y]| [x, y]))
|
||||||
|
.collect();
|
||||||
|
let y = ndarray::Array3::from_shape_vec((num_labels, num_points, 2), flat)?.into_dyn();
|
||||||
|
|
||||||
|
Ok((y * ratio).into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn point_labels(&self) -> anyhow::Result<crate::X> {
|
||||||
|
// [num_labels,num_points]
|
||||||
|
let num_labels = self.labels.len();
|
||||||
|
let num_points = if num_labels > 0 {
|
||||||
|
self.labels[0].len()
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let flat: Vec<f32> = self.labels.iter().flat_map(|v| v.iter().copied()).collect();
|
||||||
|
let y = ndarray::Array2::from_shape_vec((num_labels, num_points), flat)?.into_dyn();
|
||||||
|
Ok(y.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_xyxy(mut self, x1: f32, y1: f32, x2: f32, y2: f32) -> Self {
|
||||||
|
// TODO: if already has points, push_front coords
|
||||||
|
self.coords.push(vec![[x1, y1], [x2, y2]]);
|
||||||
|
self.labels.push(vec![2., 3.]);
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_positive_point(mut self, x: f32, y: f32) -> Self {
|
||||||
|
self = self.add_point(x, y, 1.);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_negative_point(mut self, x: f32, y: f32) -> Self {
|
||||||
|
self = self.add_point(x, y, 0.);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_point(mut self, x: f32, y: f32, id: f32) -> Self {
|
||||||
|
if self.coords.is_empty() {
|
||||||
|
self.coords.push(vec![[x, y]]);
|
||||||
|
self.labels.push(vec![id]);
|
||||||
|
} else {
|
||||||
|
if let Some(last) = self.coords.last_mut() {
|
||||||
|
last.extend_from_slice(&[[x, y]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(last) = self.labels.last_mut() {
|
||||||
|
last.extend_from_slice(&[id]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_positive_point_object(mut self, x: f32, y: f32) -> Self {
|
||||||
|
self.coords.push(vec![[x, y]]);
|
||||||
|
self.labels.push(vec![1.]);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_negative_point_object(mut self, x: f32, y: f32) -> Self {
|
||||||
|
self.coords.push(vec![[x, y]]);
|
||||||
|
self.labels.push(vec![0.]);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
10
src/models/sam2/README.md
Normal file
10
src/models/sam2/README.md
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# Segment Anything Model
|
||||||
|
|
||||||
|
## Official Repository
|
||||||
|
|
||||||
|
The official repository can be found on [sam2](https://github.com/facebookresearch/sam2)
|
||||||
|
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
Refer to the [example](../../../examples/sam2)
|
50
src/models/sam2/config.rs
Normal file
50
src/models/sam2/config.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
use crate::Options;
|
||||||
|
|
||||||
|
/// Model configuration for `SAM2.1`
|
||||||
|
impl Options {
|
||||||
|
pub fn sam2_encoder() -> Self {
|
||||||
|
Self::sam()
|
||||||
|
.with_model_ixx(0, 2, 1024.into())
|
||||||
|
.with_model_ixx(0, 3, 1024.into())
|
||||||
|
.with_resize_mode(crate::ResizeMode::FitAdaptive)
|
||||||
|
.with_resize_filter("Bilinear")
|
||||||
|
.with_image_mean(&[0.485, 0.456, 0.406])
|
||||||
|
.with_image_std(&[0.229, 0.224, 0.225])
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sam2_decoder() -> Self {
|
||||||
|
Self::sam()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sam2_1_tiny_encoder() -> Self {
|
||||||
|
Self::sam2_encoder().with_model_file("sam2.1-hiera-tiny-encoder.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sam2_1_tiny_decoder() -> Self {
|
||||||
|
Self::sam2_decoder().with_model_file("sam2.1-hiera-tiny-decoder.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sam2_1_small_encoder() -> Self {
|
||||||
|
Self::sam2_encoder().with_model_file("sam2.1-hiera-small-encoder.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sam2_1_small_decoder() -> Self {
|
||||||
|
Self::sam2_decoder().with_model_file("sam2.1-hiera-small-decoder.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sam2_1_base_plus_encoder() -> Self {
|
||||||
|
Self::sam2_encoder().with_model_file("sam2.1-hiera-base-plus-encoder.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sam2_1_base_plus_decoder() -> Self {
|
||||||
|
Self::sam2_decoder().with_model_file("sam2.1-hiera-base-plus-decoder.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sam2_1_large_encoder() -> Self {
|
||||||
|
Self::sam2_encoder().with_model_file("sam2.1-hiera-large-encoder.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sam2_1_large_decoder() -> Self {
|
||||||
|
Self::sam2_decoder().with_model_file("sam2.1-hiera-large-decoder.onnx")
|
||||||
|
}
|
||||||
|
}
|
164
src/models/sam2/impl.rs
Normal file
164
src/models/sam2/impl.rs
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
use aksr::Builder;
|
||||||
|
use anyhow::Result;
|
||||||
|
use ndarray::{s, Axis};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
elapsed, DynConf, Engine, Image, Mask, Ops, Options, Processor, SamPrompt, Ts, Xs, X, Y,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Builder, Debug)]
|
||||||
|
pub struct SAM2 {
|
||||||
|
encoder: Engine,
|
||||||
|
decoder: Engine,
|
||||||
|
height: usize,
|
||||||
|
width: usize,
|
||||||
|
batch: usize,
|
||||||
|
processor: Processor,
|
||||||
|
conf: DynConf,
|
||||||
|
ts: Ts,
|
||||||
|
spec: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SAM2 {
|
||||||
|
pub fn new(options_encoder: Options, options_decoder: Options) -> Result<Self> {
|
||||||
|
let encoder = options_encoder.to_engine()?;
|
||||||
|
let decoder = options_decoder.to_engine()?;
|
||||||
|
let (batch, height, width) = (
|
||||||
|
encoder.batch().opt(),
|
||||||
|
encoder.try_height().unwrap_or(&1024.into()).opt(),
|
||||||
|
encoder.try_width().unwrap_or(&1024.into()).opt(),
|
||||||
|
);
|
||||||
|
let ts = Ts::merge(&[encoder.ts(), decoder.ts()]);
|
||||||
|
let spec = encoder.spec().to_owned();
|
||||||
|
let processor = options_encoder
|
||||||
|
.to_processor()?
|
||||||
|
.with_image_width(width as _)
|
||||||
|
.with_image_height(height as _);
|
||||||
|
let conf = DynConf::new(options_encoder.class_confs(), 1);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
conf,
|
||||||
|
batch,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
ts,
|
||||||
|
processor,
|
||||||
|
spec,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, xs: &[Image], prompts: &[SamPrompt]) -> Result<Vec<Y>> {
|
||||||
|
let ys = elapsed!("encode", self.ts, { self.encode(xs)? });
|
||||||
|
let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? });
|
||||||
|
|
||||||
|
Ok(ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode(&mut self, xs: &[Image]) -> Result<Xs> {
|
||||||
|
let xs_ = self.processor.process_images(xs)?;
|
||||||
|
self.encoder.run(Xs::from(xs_))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result<Vec<Y>> {
|
||||||
|
let (image_embeddings, high_res_features_0, high_res_features_1) = (&xs[0], &xs[1], &xs[2]);
|
||||||
|
|
||||||
|
let mut ys: Vec<Y> = Vec::new();
|
||||||
|
for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() {
|
||||||
|
let (image_height, image_width) = (
|
||||||
|
self.processor.images_transform_info[idx].height_src,
|
||||||
|
self.processor.images_transform_info[idx].width_src,
|
||||||
|
);
|
||||||
|
let ratio = self.processor.images_transform_info[idx].height_scale;
|
||||||
|
|
||||||
|
let ys_ = self.decoder.run(Xs::from(vec![
|
||||||
|
X::from(image_embedding.into_dyn().into_owned())
|
||||||
|
.insert_axis(0)?
|
||||||
|
.repeat(0, self.batch)?,
|
||||||
|
X::from(
|
||||||
|
high_res_features_0
|
||||||
|
.slice(s![idx, .., .., ..])
|
||||||
|
.into_dyn()
|
||||||
|
.into_owned(),
|
||||||
|
)
|
||||||
|
.insert_axis(0)?
|
||||||
|
.repeat(0, self.batch)?,
|
||||||
|
X::from(
|
||||||
|
high_res_features_1
|
||||||
|
.slice(s![idx, .., .., ..])
|
||||||
|
.into_dyn()
|
||||||
|
.into_owned(),
|
||||||
|
)
|
||||||
|
.insert_axis(0)?
|
||||||
|
.repeat(0, self.batch)?,
|
||||||
|
prompts[idx].point_coords(ratio)?,
|
||||||
|
prompts[idx].point_labels()?,
|
||||||
|
// TODO
|
||||||
|
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]),
|
||||||
|
X::zeros(&[1]),
|
||||||
|
X::from(vec![self.width as _, self.height as _]),
|
||||||
|
]))?;
|
||||||
|
|
||||||
|
let mut y_masks: Vec<Mask> = Vec::new();
|
||||||
|
|
||||||
|
// masks & confs
|
||||||
|
let (masks, confs) = (&ys_[0], &ys_[1]);
|
||||||
|
|
||||||
|
for (id, (mask, iou)) in masks
|
||||||
|
.axis_iter(Axis(0))
|
||||||
|
.zip(confs.axis_iter(Axis(0)))
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
let (i, conf) = match iou
|
||||||
|
.to_owned()
|
||||||
|
.into_raw_vec_and_offset()
|
||||||
|
.0
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|a, b| a.1.total_cmp(&b.1))
|
||||||
|
{
|
||||||
|
Some((i, c)) => (i, c),
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
if conf < self.conf[0] {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let mask = mask.slice(s![i, .., ..]);
|
||||||
|
|
||||||
|
let (h, w) = mask.dim();
|
||||||
|
let luma = Ops::resize_lumaf32_u8(
|
||||||
|
&mask.into_owned().into_raw_vec_and_offset().0,
|
||||||
|
w as _,
|
||||||
|
h as _,
|
||||||
|
image_width as _,
|
||||||
|
image_height as _,
|
||||||
|
true,
|
||||||
|
"Bilinear",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// contours
|
||||||
|
let mask = Mask::new(&luma, image_width, image_height)?.with_id(id);
|
||||||
|
y_masks.push(mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut y = Y::default();
|
||||||
|
if !y_masks.is_empty() {
|
||||||
|
y = y.with_masks(&y_masks);
|
||||||
|
}
|
||||||
|
|
||||||
|
ys.push(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn width_low_res(&self) -> usize {
|
||||||
|
self.width / 4
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn height_low_res(&self) -> usize {
|
||||||
|
self.height / 4
|
||||||
|
}
|
||||||
|
}
|
4
src/models/sam2/mod.rs
Normal file
4
src/models/sam2/mod.rs
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
mod config;
|
||||||
|
mod r#impl;
|
||||||
|
|
||||||
|
pub use r#impl::*;
|
@ -27,6 +27,14 @@ pub struct Options {
|
|||||||
pub trt_fp16: bool,
|
pub trt_fp16: bool,
|
||||||
pub profile: bool,
|
pub profile: bool,
|
||||||
|
|
||||||
|
// models
|
||||||
|
pub model_encoder_file: Option<String>,
|
||||||
|
pub model_decoder_file: Option<String>,
|
||||||
|
pub visual_encoder_file: Option<String>,
|
||||||
|
pub visual_decoder_file: Option<String>,
|
||||||
|
pub textual_encoder_file: Option<String>,
|
||||||
|
pub textual_decoder_file: Option<String>,
|
||||||
|
|
||||||
// Processor configs
|
// Processor configs
|
||||||
#[args(except(setter))]
|
#[args(except(setter))]
|
||||||
pub image_width: u32,
|
pub image_width: u32,
|
||||||
@ -113,8 +121,8 @@ pub struct Options {
|
|||||||
pub binary_thresh: Option<f32>,
|
pub binary_thresh: Option<f32>,
|
||||||
|
|
||||||
// For SAM
|
// For SAM
|
||||||
pub sam_kind: Option<SamKind>,
|
pub sam_kind: Option<SamKind>, // TODO: remove
|
||||||
pub low_res_mask: Option<bool>,
|
pub low_res_mask: Option<bool>, // TODO: remove
|
||||||
|
|
||||||
// Others
|
// Others
|
||||||
pub ort_graph_opt_level: Option<u8>,
|
pub ort_graph_opt_level: Option<u8>,
|
||||||
@ -203,6 +211,12 @@ impl Default for Options {
|
|||||||
topk_2: None,
|
topk_2: None,
|
||||||
topk_3: None,
|
topk_3: None,
|
||||||
ort_graph_opt_level: None,
|
ort_graph_opt_level: None,
|
||||||
|
model_encoder_file: None,
|
||||||
|
model_decoder_file: None,
|
||||||
|
visual_encoder_file: None,
|
||||||
|
visual_decoder_file: None,
|
||||||
|
textual_encoder_file: None,
|
||||||
|
textual_decoder_file: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user