mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Add SAM2.1 models and support batched prompt inputs (#89)
This commit is contained in:
@ -16,7 +16,7 @@ struct Args {
|
||||
scale: String,
|
||||
|
||||
/// SAM kind
|
||||
#[argh(option, default = "String::from(\"sam\")")]
|
||||
#[argh(option, default = "String::from(\"samhq\")")]
|
||||
kind: String,
|
||||
}
|
||||
|
||||
@ -69,9 +69,19 @@ fn main() -> Result<()> {
|
||||
// Prompt
|
||||
let prompts = vec![
|
||||
SamPrompt::default()
|
||||
// .with_postive_point(500., 375.), // postive point
|
||||
// .with_negative_point(774., 366.), // negative point
|
||||
.with_bbox(215., 297., 643., 459.), // bbox
|
||||
// // # demo: point + point
|
||||
// .with_positive_point(500., 375.) // mid window
|
||||
// .with_positive_point(1125., 625.), // car door
|
||||
// // # demo: bbox
|
||||
// .with_xyxy(425., 600., 700., 875.), // left wheel
|
||||
// // Note: When specifying multiple boxes for multiple objects, only the last box is supported; all previous boxes will be ignored.
|
||||
// .with_xyxy(75., 275., 1725., 850.)
|
||||
// .with_xyxy(425., 600., 700., 875.)
|
||||
// .with_xyxy(1240., 675., 1400., 750.)
|
||||
// .with_xyxy(1375., 550., 1650., 800.)
|
||||
// # demo: bbox + negative point
|
||||
.with_xyxy(425., 600., 700., 875.) // left wheel
|
||||
.with_negative_point(575., 750.), // tire
|
||||
];
|
||||
|
||||
// Run & Annotate
|
||||
|
6
examples/sam2/README.md
Normal file
6
examples/sam2/README.md
Normal file
@ -0,0 +1,6 @@
|
||||
## Quick Start
|
||||
|
||||
```Shell
|
||||
|
||||
cargo run -r -F cuda --example sam -- --device cuda --scale t
|
||||
```
|
93
examples/sam2/main.rs
Normal file
93
examples/sam2/main.rs
Normal file
@ -0,0 +1,93 @@
|
||||
use anyhow::Result;
|
||||
use usls::{
|
||||
models::{SamPrompt, SAM2},
|
||||
Annotator, DataLoader, Options, Scale,
|
||||
};
|
||||
|
||||
#[derive(argh::FromArgs)]
|
||||
/// Example
|
||||
struct Args {
|
||||
/// device
|
||||
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||
device: String,
|
||||
|
||||
/// scale
|
||||
#[argh(option, default = "String::from(\"t\")")]
|
||||
scale: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
||||
.init();
|
||||
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// Build model
|
||||
let (options_encoder, options_decoder) = match args.scale.as_str().try_into()? {
|
||||
Scale::T => (
|
||||
Options::sam2_1_tiny_encoder(),
|
||||
Options::sam2_1_tiny_decoder(),
|
||||
),
|
||||
Scale::S => (
|
||||
Options::sam2_1_small_encoder(),
|
||||
Options::sam2_1_small_decoder(),
|
||||
),
|
||||
Scale::B => (
|
||||
Options::sam2_1_base_plus_encoder(),
|
||||
Options::sam2_1_base_plus_decoder(),
|
||||
),
|
||||
Scale::L => (
|
||||
Options::sam2_1_large_encoder(),
|
||||
Options::sam2_1_large_decoder(),
|
||||
),
|
||||
_ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale),
|
||||
};
|
||||
|
||||
let options_encoder = options_encoder
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.commit()?;
|
||||
let options_decoder = options_decoder
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.commit()?;
|
||||
let mut model = SAM2::new(options_encoder, options_decoder)?;
|
||||
|
||||
// Load image
|
||||
let xs = DataLoader::try_read_n(&["images/truck.jpg"])?;
|
||||
|
||||
// Prompt
|
||||
let prompts = vec![SamPrompt::default()
|
||||
// // # demo: point + point
|
||||
// .with_positive_point(500., 375.) // mid window
|
||||
// .with_positive_point(1125., 625.), // car door
|
||||
// // # demo: bbox
|
||||
// .with_xyxy(425., 600., 700., 875.), // left wheel
|
||||
// // # demo: bbox + negative point
|
||||
// .with_xyxy(425., 600., 700., 875.) // left wheel
|
||||
// .with_negative_point(575., 750.), // tire
|
||||
// # demo: multiple objects with boxes
|
||||
.with_xyxy(75., 275., 1725., 850.)
|
||||
.with_xyxy(425., 600., 700., 875.)
|
||||
.with_xyxy(1375., 550., 1650., 800.)
|
||||
.with_xyxy(1240., 675., 1400., 750.)];
|
||||
|
||||
// Run & Annotate
|
||||
let ys = model.forward(&xs, &prompts)?;
|
||||
|
||||
// annotate
|
||||
let annotator = Annotator::default()
|
||||
.with_mask_style(usls::Style::mask().with_draw_mask_polygon_largest(true));
|
||||
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
annotator.annotate(x, y)?.save(format!(
|
||||
"{}.jpg",
|
||||
usls::Dir::Current
|
||||
.base_dir_with_subs(&["runs", model.spec()])?
|
||||
.join(usls::timestamp(None))
|
||||
.display(),
|
||||
))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use usls::{
|
||||
models::{SamPrompt, SAM, YOLO},
|
||||
Annotator, DataLoader, Options, Scale, Style,
|
||||
};
|
||||
|
||||
#[derive(argh::FromArgs)]
|
||||
/// Example
|
||||
struct Args {
|
||||
/// device
|
||||
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||
device: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
||||
.init();
|
||||
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// build SAM
|
||||
let (options_encoder, options_decoder) = (
|
||||
Options::mobile_sam_tiny_encoder().commit()?,
|
||||
Options::mobile_sam_tiny_decoder().commit()?,
|
||||
);
|
||||
let mut sam = SAM::new(options_encoder, options_decoder)?;
|
||||
|
||||
// build YOLOv8
|
||||
let options_yolo = Options::yolo_detect()
|
||||
.with_model_scale(Scale::N)
|
||||
.with_model_version(8.into())
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.commit()?;
|
||||
let mut yolo = YOLO::new(options_yolo)?;
|
||||
|
||||
// load one image
|
||||
let xs = DataLoader::try_read_n(&["images/dog.jpg"])?;
|
||||
|
||||
// build annotator
|
||||
let annotator = Annotator::default().with_hbb_style(Style::hbb().with_draw_fill(true));
|
||||
|
||||
// run & annotate
|
||||
let ys_det = yolo.forward(&xs)?;
|
||||
for y_det in ys_det.iter() {
|
||||
if let Some(hbbs) = y_det.hbbs() {
|
||||
for hbb in hbbs {
|
||||
let ys_sam = sam.forward(
|
||||
&xs,
|
||||
&[SamPrompt::default().with_bbox(
|
||||
hbb.xmin(),
|
||||
hbb.ymin(),
|
||||
hbb.xmax(),
|
||||
hbb.ymax(),
|
||||
)],
|
||||
)?;
|
||||
// annotator.annotate(&xs, &ys_sam);
|
||||
for (x, y) in xs.iter().zip(ys_sam.iter()) {
|
||||
annotator.annotate(x, y)?.save(format!(
|
||||
"{}.jpg",
|
||||
usls::Dir::Current
|
||||
.base_dir_with_subs(&["runs", "YOLO-SAM"])?
|
||||
.join(usls::timestamp(None))
|
||||
.display(),
|
||||
))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
79
examples/yolo-sam2/main.rs
Normal file
79
examples/yolo-sam2/main.rs
Normal file
@ -0,0 +1,79 @@
|
||||
use anyhow::Result;
|
||||
use usls::{
|
||||
models::{SamPrompt, SAM2, YOLO},
|
||||
Annotator, DataLoader, Options, Scale, Style,
|
||||
};
|
||||
|
||||
#[derive(argh::FromArgs)]
|
||||
/// Example
|
||||
struct Args {
|
||||
/// device
|
||||
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||
device: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
||||
.init();
|
||||
|
||||
let args: Args = argh::from_env();
|
||||
|
||||
// build SAM
|
||||
let (options_encoder, options_decoder) = (
|
||||
Options::sam2_1_tiny_encoder().commit()?,
|
||||
Options::sam2_1_tiny_decoder().commit()?,
|
||||
);
|
||||
let mut sam = SAM2::new(options_encoder, options_decoder)?;
|
||||
|
||||
// build YOLOv8
|
||||
let options_yolo = Options::yolo_detect()
|
||||
.with_model_scale(Scale::N)
|
||||
.with_model_version(8.into())
|
||||
.with_model_device(args.device.as_str().try_into()?)
|
||||
.commit()?;
|
||||
let mut yolo = YOLO::new(options_yolo)?;
|
||||
|
||||
// load one image
|
||||
let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?;
|
||||
|
||||
// build annotator
|
||||
let annotator = Annotator::default()
|
||||
.with_polygon_style(
|
||||
Style::polygon()
|
||||
.with_visible(true)
|
||||
.with_text_visible(true)
|
||||
.show_id(true)
|
||||
.show_name(true),
|
||||
)
|
||||
.with_mask_style(Style::mask().with_draw_mask_polygon_largest(true));
|
||||
|
||||
// run & annotate
|
||||
let ys_det = yolo.forward(&xs)?;
|
||||
for y_det in ys_det.iter() {
|
||||
if let Some(hbbs) = y_det.hbbs() {
|
||||
// collect hhbs
|
||||
let mut prompt = SamPrompt::default();
|
||||
for hbb in hbbs {
|
||||
prompt = prompt.with_xyxy(hbb.xmin(), hbb.ymin(), hbb.xmax(), hbb.ymax());
|
||||
}
|
||||
|
||||
// sam2 infer
|
||||
let ys_sam = sam.forward(&xs, &[prompt])?;
|
||||
|
||||
// annotate
|
||||
for (x, y) in xs.iter().zip(ys_sam.iter()) {
|
||||
annotator.annotate(x, y)?.save(format!(
|
||||
"{}.jpg",
|
||||
usls::Dir::Current
|
||||
.base_dir_with_subs(&["runs", "YOLO-SAM2"])?
|
||||
.join(usls::timestamp(None))
|
||||
.display(),
|
||||
))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -24,6 +24,7 @@ mod rfdetr;
|
||||
mod rtdetr;
|
||||
mod rtmo;
|
||||
mod sam;
|
||||
mod sam2;
|
||||
mod sapiens;
|
||||
mod slanet;
|
||||
mod smolvlm;
|
||||
@ -49,6 +50,7 @@ pub use rfdetr::*;
|
||||
pub use rtdetr::*;
|
||||
pub use rtmo::*;
|
||||
pub use sam::*;
|
||||
pub use sam2::*;
|
||||
pub use sapiens::*;
|
||||
pub use slanet::*;
|
||||
pub use smolvlm::*;
|
||||
|
@ -1,16 +1,17 @@
|
||||
use aksr::Builder;
|
||||
use anyhow::Result;
|
||||
use ndarray::{s, Array, Axis};
|
||||
use ndarray::{s, Axis};
|
||||
use rand::prelude::*;
|
||||
|
||||
use crate::{
|
||||
elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, Ts, Xs, X, Y,
|
||||
elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, SamPrompt, Ts, Xs, X,
|
||||
Y,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SamKind {
|
||||
Sam,
|
||||
Sam2,
|
||||
Sam2, // 2.0
|
||||
MobileSam,
|
||||
SamHq,
|
||||
EdgeSam,
|
||||
@ -31,54 +32,6 @@ impl TryFrom<&str> for SamKind {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct SamPrompt {
|
||||
points: Vec<f32>,
|
||||
labels: Vec<f32>,
|
||||
}
|
||||
|
||||
impl SamPrompt {
|
||||
pub fn everything() -> Self {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn with_postive_point(mut self, x: f32, y: f32) -> Self {
|
||||
self.points.extend_from_slice(&[x, y]);
|
||||
self.labels.push(1.);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_negative_point(mut self, x: f32, y: f32) -> Self {
|
||||
self.points.extend_from_slice(&[x, y]);
|
||||
self.labels.push(0.);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_bbox(mut self, x: f32, y: f32, x2: f32, y2: f32) -> Self {
|
||||
self.points.extend_from_slice(&[x, y, x2, y2]);
|
||||
self.labels.extend_from_slice(&[2., 3.]);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn point_coords(&self, r: f32) -> Result<X> {
|
||||
let point_coords = Array::from_shape_vec((1, self.num_points(), 2), self.points.clone())?
|
||||
.into_dyn()
|
||||
.into_owned();
|
||||
Ok(X::from(point_coords * r))
|
||||
}
|
||||
|
||||
pub fn point_labels(&self) -> Result<X> {
|
||||
let point_labels = Array::from_shape_vec((1, self.num_points()), self.labels.clone())?
|
||||
.into_dyn()
|
||||
.into_owned();
|
||||
Ok(X::from(point_labels))
|
||||
}
|
||||
|
||||
pub fn num_points(&self) -> usize {
|
||||
self.points.len() / 2
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct SAM {
|
||||
encoder: Engine,
|
||||
@ -167,14 +120,28 @@ impl SAM {
|
||||
);
|
||||
let ratio = self.processor.images_transform_info[idx].height_scale;
|
||||
|
||||
let (mut point_coords, mut point_labels) = (
|
||||
prompts[idx].point_coords(ratio)?,
|
||||
prompts[idx].point_labels()?,
|
||||
);
|
||||
|
||||
if point_coords.shape()[0] != 1 {
|
||||
point_coords = X::from(point_coords.slice(s![-1, .., ..]).to_owned().into_dyn())
|
||||
.insert_axis(0)?;
|
||||
}
|
||||
if point_labels.shape()[0] != 1 {
|
||||
point_labels = X::from(point_labels.slice(s![-1, ..,]).to_owned().into_dyn())
|
||||
.insert_axis(0)?;
|
||||
}
|
||||
|
||||
let args = match self.kind {
|
||||
SamKind::Sam | SamKind::MobileSam => {
|
||||
vec![
|
||||
X::from(image_embedding.into_dyn().into_owned())
|
||||
.insert_axis(0)?
|
||||
.repeat(0, self.batch)?, // image_embedding
|
||||
prompts[idx].point_coords(ratio)?, // point_coords
|
||||
prompts[idx].point_labels()?, // point_labels
|
||||
point_coords,
|
||||
point_labels,
|
||||
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input,
|
||||
X::zeros(&[1]), // has_mask_input
|
||||
X::from(vec![image_height as _, image_width as _]), // orig_im_size
|
||||
@ -189,8 +156,8 @@ impl SAM {
|
||||
.insert_axis(0)?
|
||||
.insert_axis(0)?
|
||||
.repeat(0, self.batch)?, // intern_embedding
|
||||
prompts[idx].point_coords(ratio)?, // point_coords
|
||||
prompts[idx].point_labels()?, // point_labels
|
||||
point_coords,
|
||||
point_labels,
|
||||
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
|
||||
X::zeros(&[1]), // has_mask_input
|
||||
X::from(vec![image_height as _, image_width as _]), // orig_im_size
|
||||
@ -201,8 +168,8 @@ impl SAM {
|
||||
X::from(image_embedding.into_dyn().into_owned())
|
||||
.insert_axis(0)?
|
||||
.repeat(0, self.batch)?,
|
||||
prompts[idx].point_coords(ratio)?,
|
||||
prompts[idx].point_labels()?,
|
||||
point_coords,
|
||||
point_labels,
|
||||
]
|
||||
}
|
||||
SamKind::Sam2 => {
|
||||
@ -228,11 +195,11 @@ impl SAM {
|
||||
)
|
||||
.insert_axis(0)?
|
||||
.repeat(0, self.batch)?,
|
||||
prompts[idx].point_coords(ratio)?,
|
||||
prompts[idx].point_labels()?,
|
||||
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
|
||||
X::zeros(&[1]), // has_mask_input
|
||||
X::from(vec![image_height as _, image_width as _]), // orig_im_size
|
||||
point_coords,
|
||||
point_labels,
|
||||
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]),
|
||||
X::zeros(&[1]),
|
||||
X::from(vec![image_height as _, image_width as _]),
|
||||
]
|
||||
}
|
||||
};
|
||||
|
@ -2,3 +2,88 @@ mod config;
|
||||
mod r#impl;
|
||||
|
||||
pub use r#impl::*;
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct SamPrompt {
|
||||
pub coords: Vec<Vec<[f32; 2]>>,
|
||||
pub labels: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
impl SamPrompt {
|
||||
pub fn point_coords(&self, ratio: f32) -> anyhow::Result<crate::X> {
|
||||
// [num_labels,num_points,2]
|
||||
let num_labels = self.coords.len();
|
||||
let num_points = if num_labels > 0 {
|
||||
self.coords[0].len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let flat: Vec<f32> = self
|
||||
.coords
|
||||
.iter()
|
||||
.flat_map(|v| v.iter().flat_map(|&[x, y]| [x, y]))
|
||||
.collect();
|
||||
let y = ndarray::Array3::from_shape_vec((num_labels, num_points, 2), flat)?.into_dyn();
|
||||
|
||||
Ok((y * ratio).into())
|
||||
}
|
||||
|
||||
pub fn point_labels(&self) -> anyhow::Result<crate::X> {
|
||||
// [num_labels,num_points]
|
||||
let num_labels = self.labels.len();
|
||||
let num_points = if num_labels > 0 {
|
||||
self.labels[0].len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let flat: Vec<f32> = self.labels.iter().flat_map(|v| v.iter().copied()).collect();
|
||||
let y = ndarray::Array2::from_shape_vec((num_labels, num_points), flat)?.into_dyn();
|
||||
Ok(y.into())
|
||||
}
|
||||
|
||||
pub fn with_xyxy(mut self, x1: f32, y1: f32, x2: f32, y2: f32) -> Self {
|
||||
// TODO: if already has points, push_front coords
|
||||
self.coords.push(vec![[x1, y1], [x2, y2]]);
|
||||
self.labels.push(vec![2., 3.]);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_positive_point(mut self, x: f32, y: f32) -> Self {
|
||||
self = self.add_point(x, y, 1.);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_negative_point(mut self, x: f32, y: f32) -> Self {
|
||||
self = self.add_point(x, y, 0.);
|
||||
self
|
||||
}
|
||||
|
||||
fn add_point(mut self, x: f32, y: f32, id: f32) -> Self {
|
||||
if self.coords.is_empty() {
|
||||
self.coords.push(vec![[x, y]]);
|
||||
self.labels.push(vec![id]);
|
||||
} else {
|
||||
if let Some(last) = self.coords.last_mut() {
|
||||
last.extend_from_slice(&[[x, y]]);
|
||||
}
|
||||
|
||||
if let Some(last) = self.labels.last_mut() {
|
||||
last.extend_from_slice(&[id]);
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_positive_point_object(mut self, x: f32, y: f32) -> Self {
|
||||
self.coords.push(vec![[x, y]]);
|
||||
self.labels.push(vec![1.]);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_negative_point_object(mut self, x: f32, y: f32) -> Self {
|
||||
self.coords.push(vec![[x, y]]);
|
||||
self.labels.push(vec![0.]);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
10
src/models/sam2/README.md
Normal file
10
src/models/sam2/README.md
Normal file
@ -0,0 +1,10 @@
|
||||
# Segment Anything Model
|
||||
|
||||
## Official Repository
|
||||
|
||||
The official repository can be found on [sam2](https://github.com/facebookresearch/sam2)
|
||||
|
||||
|
||||
## Example
|
||||
|
||||
Refer to the [example](../../../examples/sam2)
|
50
src/models/sam2/config.rs
Normal file
50
src/models/sam2/config.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use crate::Options;
|
||||
|
||||
/// Model configuration for `SAM2.1`
|
||||
impl Options {
|
||||
pub fn sam2_encoder() -> Self {
|
||||
Self::sam()
|
||||
.with_model_ixx(0, 2, 1024.into())
|
||||
.with_model_ixx(0, 3, 1024.into())
|
||||
.with_resize_mode(crate::ResizeMode::FitAdaptive)
|
||||
.with_resize_filter("Bilinear")
|
||||
.with_image_mean(&[0.485, 0.456, 0.406])
|
||||
.with_image_std(&[0.229, 0.224, 0.225])
|
||||
}
|
||||
|
||||
pub fn sam2_decoder() -> Self {
|
||||
Self::sam()
|
||||
}
|
||||
|
||||
pub fn sam2_1_tiny_encoder() -> Self {
|
||||
Self::sam2_encoder().with_model_file("sam2.1-hiera-tiny-encoder.onnx")
|
||||
}
|
||||
|
||||
pub fn sam2_1_tiny_decoder() -> Self {
|
||||
Self::sam2_decoder().with_model_file("sam2.1-hiera-tiny-decoder.onnx")
|
||||
}
|
||||
|
||||
pub fn sam2_1_small_encoder() -> Self {
|
||||
Self::sam2_encoder().with_model_file("sam2.1-hiera-small-encoder.onnx")
|
||||
}
|
||||
|
||||
pub fn sam2_1_small_decoder() -> Self {
|
||||
Self::sam2_decoder().with_model_file("sam2.1-hiera-small-decoder.onnx")
|
||||
}
|
||||
|
||||
pub fn sam2_1_base_plus_encoder() -> Self {
|
||||
Self::sam2_encoder().with_model_file("sam2.1-hiera-base-plus-encoder.onnx")
|
||||
}
|
||||
|
||||
pub fn sam2_1_base_plus_decoder() -> Self {
|
||||
Self::sam2_decoder().with_model_file("sam2.1-hiera-base-plus-decoder.onnx")
|
||||
}
|
||||
|
||||
pub fn sam2_1_large_encoder() -> Self {
|
||||
Self::sam2_encoder().with_model_file("sam2.1-hiera-large-encoder.onnx")
|
||||
}
|
||||
|
||||
pub fn sam2_1_large_decoder() -> Self {
|
||||
Self::sam2_decoder().with_model_file("sam2.1-hiera-large-decoder.onnx")
|
||||
}
|
||||
}
|
164
src/models/sam2/impl.rs
Normal file
164
src/models/sam2/impl.rs
Normal file
@ -0,0 +1,164 @@
|
||||
use aksr::Builder;
|
||||
use anyhow::Result;
|
||||
use ndarray::{s, Axis};
|
||||
|
||||
use crate::{
|
||||
elapsed, DynConf, Engine, Image, Mask, Ops, Options, Processor, SamPrompt, Ts, Xs, X, Y,
|
||||
};
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct SAM2 {
|
||||
encoder: Engine,
|
||||
decoder: Engine,
|
||||
height: usize,
|
||||
width: usize,
|
||||
batch: usize,
|
||||
processor: Processor,
|
||||
conf: DynConf,
|
||||
ts: Ts,
|
||||
spec: String,
|
||||
}
|
||||
|
||||
impl SAM2 {
|
||||
pub fn new(options_encoder: Options, options_decoder: Options) -> Result<Self> {
|
||||
let encoder = options_encoder.to_engine()?;
|
||||
let decoder = options_decoder.to_engine()?;
|
||||
let (batch, height, width) = (
|
||||
encoder.batch().opt(),
|
||||
encoder.try_height().unwrap_or(&1024.into()).opt(),
|
||||
encoder.try_width().unwrap_or(&1024.into()).opt(),
|
||||
);
|
||||
let ts = Ts::merge(&[encoder.ts(), decoder.ts()]);
|
||||
let spec = encoder.spec().to_owned();
|
||||
let processor = options_encoder
|
||||
.to_processor()?
|
||||
.with_image_width(width as _)
|
||||
.with_image_height(height as _);
|
||||
let conf = DynConf::new(options_encoder.class_confs(), 1);
|
||||
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
conf,
|
||||
batch,
|
||||
height,
|
||||
width,
|
||||
ts,
|
||||
processor,
|
||||
spec,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &[Image], prompts: &[SamPrompt]) -> Result<Vec<Y>> {
|
||||
let ys = elapsed!("encode", self.ts, { self.encode(xs)? });
|
||||
let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? });
|
||||
|
||||
Ok(ys)
|
||||
}
|
||||
|
||||
pub fn encode(&mut self, xs: &[Image]) -> Result<Xs> {
|
||||
let xs_ = self.processor.process_images(xs)?;
|
||||
self.encoder.run(Xs::from(xs_))
|
||||
}
|
||||
|
||||
pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result<Vec<Y>> {
|
||||
let (image_embeddings, high_res_features_0, high_res_features_1) = (&xs[0], &xs[1], &xs[2]);
|
||||
|
||||
let mut ys: Vec<Y> = Vec::new();
|
||||
for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() {
|
||||
let (image_height, image_width) = (
|
||||
self.processor.images_transform_info[idx].height_src,
|
||||
self.processor.images_transform_info[idx].width_src,
|
||||
);
|
||||
let ratio = self.processor.images_transform_info[idx].height_scale;
|
||||
|
||||
let ys_ = self.decoder.run(Xs::from(vec![
|
||||
X::from(image_embedding.into_dyn().into_owned())
|
||||
.insert_axis(0)?
|
||||
.repeat(0, self.batch)?,
|
||||
X::from(
|
||||
high_res_features_0
|
||||
.slice(s![idx, .., .., ..])
|
||||
.into_dyn()
|
||||
.into_owned(),
|
||||
)
|
||||
.insert_axis(0)?
|
||||
.repeat(0, self.batch)?,
|
||||
X::from(
|
||||
high_res_features_1
|
||||
.slice(s![idx, .., .., ..])
|
||||
.into_dyn()
|
||||
.into_owned(),
|
||||
)
|
||||
.insert_axis(0)?
|
||||
.repeat(0, self.batch)?,
|
||||
prompts[idx].point_coords(ratio)?,
|
||||
prompts[idx].point_labels()?,
|
||||
// TODO
|
||||
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]),
|
||||
X::zeros(&[1]),
|
||||
X::from(vec![self.width as _, self.height as _]),
|
||||
]))?;
|
||||
|
||||
let mut y_masks: Vec<Mask> = Vec::new();
|
||||
|
||||
// masks & confs
|
||||
let (masks, confs) = (&ys_[0], &ys_[1]);
|
||||
|
||||
for (id, (mask, iou)) in masks
|
||||
.axis_iter(Axis(0))
|
||||
.zip(confs.axis_iter(Axis(0)))
|
||||
.enumerate()
|
||||
{
|
||||
let (i, conf) = match iou
|
||||
.to_owned()
|
||||
.into_raw_vec_and_offset()
|
||||
.0
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.total_cmp(&b.1))
|
||||
{
|
||||
Some((i, c)) => (i, c),
|
||||
None => continue,
|
||||
};
|
||||
|
||||
if conf < self.conf[0] {
|
||||
continue;
|
||||
}
|
||||
let mask = mask.slice(s![i, .., ..]);
|
||||
|
||||
let (h, w) = mask.dim();
|
||||
let luma = Ops::resize_lumaf32_u8(
|
||||
&mask.into_owned().into_raw_vec_and_offset().0,
|
||||
w as _,
|
||||
h as _,
|
||||
image_width as _,
|
||||
image_height as _,
|
||||
true,
|
||||
"Bilinear",
|
||||
)?;
|
||||
|
||||
// contours
|
||||
let mask = Mask::new(&luma, image_width, image_height)?.with_id(id);
|
||||
y_masks.push(mask);
|
||||
}
|
||||
|
||||
let mut y = Y::default();
|
||||
if !y_masks.is_empty() {
|
||||
y = y.with_masks(&y_masks);
|
||||
}
|
||||
|
||||
ys.push(y);
|
||||
}
|
||||
|
||||
Ok(ys)
|
||||
}
|
||||
|
||||
pub fn width_low_res(&self) -> usize {
|
||||
self.width / 4
|
||||
}
|
||||
|
||||
pub fn height_low_res(&self) -> usize {
|
||||
self.height / 4
|
||||
}
|
||||
}
|
4
src/models/sam2/mod.rs
Normal file
4
src/models/sam2/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
mod config;
|
||||
mod r#impl;
|
||||
|
||||
pub use r#impl::*;
|
@ -27,6 +27,14 @@ pub struct Options {
|
||||
pub trt_fp16: bool,
|
||||
pub profile: bool,
|
||||
|
||||
// models
|
||||
pub model_encoder_file: Option<String>,
|
||||
pub model_decoder_file: Option<String>,
|
||||
pub visual_encoder_file: Option<String>,
|
||||
pub visual_decoder_file: Option<String>,
|
||||
pub textual_encoder_file: Option<String>,
|
||||
pub textual_decoder_file: Option<String>,
|
||||
|
||||
// Processor configs
|
||||
#[args(except(setter))]
|
||||
pub image_width: u32,
|
||||
@ -113,8 +121,8 @@ pub struct Options {
|
||||
pub binary_thresh: Option<f32>,
|
||||
|
||||
// For SAM
|
||||
pub sam_kind: Option<SamKind>,
|
||||
pub low_res_mask: Option<bool>,
|
||||
pub sam_kind: Option<SamKind>, // TODO: remove
|
||||
pub low_res_mask: Option<bool>, // TODO: remove
|
||||
|
||||
// Others
|
||||
pub ort_graph_opt_level: Option<u8>,
|
||||
@ -203,6 +211,12 @@ impl Default for Options {
|
||||
topk_2: None,
|
||||
topk_3: None,
|
||||
ort_graph_opt_level: None,
|
||||
model_encoder_file: None,
|
||||
model_decoder_file: None,
|
||||
visual_encoder_file: None,
|
||||
visual_decoder_file: None,
|
||||
textual_encoder_file: None,
|
||||
textual_decoder_file: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user