diff --git a/examples/sam/main.rs b/examples/sam/main.rs index 29607b7..21e4fcc 100644 --- a/examples/sam/main.rs +++ b/examples/sam/main.rs @@ -16,7 +16,7 @@ struct Args { scale: String, /// SAM kind - #[argh(option, default = "String::from(\"sam\")")] + #[argh(option, default = "String::from(\"samhq\")")] kind: String, } @@ -69,9 +69,19 @@ fn main() -> Result<()> { // Prompt let prompts = vec![ SamPrompt::default() - // .with_postive_point(500., 375.), // postive point - // .with_negative_point(774., 366.), // negative point - .with_bbox(215., 297., 643., 459.), // bbox + // // # demo: point + point + // .with_positive_point(500., 375.) // mid window + // .with_positive_point(1125., 625.), // car door + // // # demo: bbox + // .with_xyxy(425., 600., 700., 875.), // left wheel + // // Note: When specifying multiple boxes for multiple objects, only the last box is supported; all previous boxes will be ignored. + // .with_xyxy(75., 275., 1725., 850.) + // .with_xyxy(425., 600., 700., 875.) + // .with_xyxy(1240., 675., 1400., 750.) + // .with_xyxy(1375., 550., 1650., 800.) + // # demo: bbox + negative point + .with_xyxy(425., 600., 700., 875.) // left wheel + .with_negative_point(575., 750.), // tire ]; // Run & Annotate diff --git a/examples/sam2/README.md b/examples/sam2/README.md new file mode 100644 index 0000000..5a97486 --- /dev/null +++ b/examples/sam2/README.md @@ -0,0 +1,6 @@ +## Quick Start + +```Shell + +cargo run -r -F cuda --example sam -- --device cuda --scale t +``` \ No newline at end of file diff --git a/examples/sam2/main.rs b/examples/sam2/main.rs new file mode 100644 index 0000000..e8722a4 --- /dev/null +++ b/examples/sam2/main.rs @@ -0,0 +1,93 @@ +use anyhow::Result; +use usls::{ + models::{SamPrompt, SAM2}, + Annotator, DataLoader, Options, Scale, +}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, + + /// scale + #[argh(option, default = "String::from(\"t\")")] + scale: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // Build model + let (options_encoder, options_decoder) = match args.scale.as_str().try_into()? { + Scale::T => ( + Options::sam2_1_tiny_encoder(), + Options::sam2_1_tiny_decoder(), + ), + Scale::S => ( + Options::sam2_1_small_encoder(), + Options::sam2_1_small_decoder(), + ), + Scale::B => ( + Options::sam2_1_base_plus_encoder(), + Options::sam2_1_base_plus_decoder(), + ), + Scale::L => ( + Options::sam2_1_large_encoder(), + Options::sam2_1_large_decoder(), + ), + _ => unimplemented!("Unsupported model scale: {:?}. Try b, s, t, l.", args.scale), + }; + + let options_encoder = options_encoder + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let options_decoder = options_decoder + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut model = SAM2::new(options_encoder, options_decoder)?; + + // Load image + let xs = DataLoader::try_read_n(&["images/truck.jpg"])?; + + // Prompt + let prompts = vec![SamPrompt::default() + // // # demo: point + point + // .with_positive_point(500., 375.) // mid window + // .with_positive_point(1125., 625.), // car door + // // # demo: bbox + // .with_xyxy(425., 600., 700., 875.), // left wheel + // // # demo: bbox + negative point + // .with_xyxy(425., 600., 700., 875.) // left wheel + // .with_negative_point(575., 750.), // tire + // # demo: multiple objects with boxes + .with_xyxy(75., 275., 1725., 850.) + .with_xyxy(425., 600., 700., 875.) + .with_xyxy(1375., 550., 1650., 800.) + .with_xyxy(1240., 675., 1400., 750.)]; + + // Run & Annotate + let ys = model.forward(&xs, &prompts)?; + + // annotate + let annotator = Annotator::default() + .with_mask_style(usls::Style::mask().with_draw_mask_polygon_largest(true)); + + for (x, y) in xs.iter().zip(ys.iter()) { + annotator.annotate(x, y)?.save(format!( + "{}.jpg", + usls::Dir::Current + .base_dir_with_subs(&["runs", model.spec()])? + .join(usls::timestamp(None)) + .display(), + ))?; + } + + Ok(()) +} diff --git a/examples/yolo-sam/main.rs b/examples/yolo-sam/main.rs deleted file mode 100644 index 74c35d6..0000000 --- a/examples/yolo-sam/main.rs +++ /dev/null @@ -1,73 +0,0 @@ -use anyhow::Result; -use usls::{ - models::{SamPrompt, SAM, YOLO}, - Annotator, DataLoader, Options, Scale, Style, -}; - -#[derive(argh::FromArgs)] -/// Example -struct Args { - /// device - #[argh(option, default = "String::from(\"cpu:0\")")] - device: String, -} - -fn main() -> Result<()> { - tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) - .init(); - - let args: Args = argh::from_env(); - - // build SAM - let (options_encoder, options_decoder) = ( - Options::mobile_sam_tiny_encoder().commit()?, - Options::mobile_sam_tiny_decoder().commit()?, - ); - let mut sam = SAM::new(options_encoder, options_decoder)?; - - // build YOLOv8 - let options_yolo = Options::yolo_detect() - .with_model_scale(Scale::N) - .with_model_version(8.into()) - .with_model_device(args.device.as_str().try_into()?) - .commit()?; - let mut yolo = YOLO::new(options_yolo)?; - - // load one image - let xs = DataLoader::try_read_n(&["images/dog.jpg"])?; - - // build annotator - let annotator = Annotator::default().with_hbb_style(Style::hbb().with_draw_fill(true)); - - // run & annotate - let ys_det = yolo.forward(&xs)?; - for y_det in ys_det.iter() { - if let Some(hbbs) = y_det.hbbs() { - for hbb in hbbs { - let ys_sam = sam.forward( - &xs, - &[SamPrompt::default().with_bbox( - hbb.xmin(), - hbb.ymin(), - hbb.xmax(), - hbb.ymax(), - )], - )?; - // annotator.annotate(&xs, &ys_sam); - for (x, y) in xs.iter().zip(ys_sam.iter()) { - annotator.annotate(x, y)?.save(format!( - "{}.jpg", - usls::Dir::Current - .base_dir_with_subs(&["runs", "YOLO-SAM"])? - .join(usls::timestamp(None)) - .display(), - ))?; - } - } - } - } - - Ok(()) -} diff --git a/examples/yolo-sam/README.md b/examples/yolo-sam2/README.md similarity index 100% rename from examples/yolo-sam/README.md rename to examples/yolo-sam2/README.md diff --git a/examples/yolo-sam2/main.rs b/examples/yolo-sam2/main.rs new file mode 100644 index 0000000..bcf9634 --- /dev/null +++ b/examples/yolo-sam2/main.rs @@ -0,0 +1,79 @@ +use anyhow::Result; +use usls::{ + models::{SamPrompt, SAM2, YOLO}, + Annotator, DataLoader, Options, Scale, Style, +}; + +#[derive(argh::FromArgs)] +/// Example +struct Args { + /// device + #[argh(option, default = "String::from(\"cpu:0\")")] + device: String, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) + .init(); + + let args: Args = argh::from_env(); + + // build SAM + let (options_encoder, options_decoder) = ( + Options::sam2_1_tiny_encoder().commit()?, + Options::sam2_1_tiny_decoder().commit()?, + ); + let mut sam = SAM2::new(options_encoder, options_decoder)?; + + // build YOLOv8 + let options_yolo = Options::yolo_detect() + .with_model_scale(Scale::N) + .with_model_version(8.into()) + .with_model_device(args.device.as_str().try_into()?) + .commit()?; + let mut yolo = YOLO::new(options_yolo)?; + + // load one image + let xs = DataLoader::try_read_n(&["./assets/bus.jpg"])?; + + // build annotator + let annotator = Annotator::default() + .with_polygon_style( + Style::polygon() + .with_visible(true) + .with_text_visible(true) + .show_id(true) + .show_name(true), + ) + .with_mask_style(Style::mask().with_draw_mask_polygon_largest(true)); + + // run & annotate + let ys_det = yolo.forward(&xs)?; + for y_det in ys_det.iter() { + if let Some(hbbs) = y_det.hbbs() { + // collect hhbs + let mut prompt = SamPrompt::default(); + for hbb in hbbs { + prompt = prompt.with_xyxy(hbb.xmin(), hbb.ymin(), hbb.xmax(), hbb.ymax()); + } + + // sam2 infer + let ys_sam = sam.forward(&xs, &[prompt])?; + + // annotate + for (x, y) in xs.iter().zip(ys_sam.iter()) { + annotator.annotate(x, y)?.save(format!( + "{}.jpg", + usls::Dir::Current + .base_dir_with_subs(&["runs", "YOLO-SAM2"])? + .join(usls::timestamp(None)) + .display(), + ))?; + } + } + } + + Ok(()) +} diff --git a/src/models/mod.rs b/src/models/mod.rs index bbab7a8..ad5f42b 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -24,6 +24,7 @@ mod rfdetr; mod rtdetr; mod rtmo; mod sam; +mod sam2; mod sapiens; mod slanet; mod smolvlm; @@ -49,6 +50,7 @@ pub use rfdetr::*; pub use rtdetr::*; pub use rtmo::*; pub use sam::*; +pub use sam2::*; pub use sapiens::*; pub use slanet::*; pub use smolvlm::*; diff --git a/src/models/sam/impl.rs b/src/models/sam/impl.rs index 1fcde3e..0e02ca5 100644 --- a/src/models/sam/impl.rs +++ b/src/models/sam/impl.rs @@ -1,16 +1,17 @@ use aksr::Builder; use anyhow::Result; -use ndarray::{s, Array, Axis}; +use ndarray::{s, Axis}; use rand::prelude::*; use crate::{ - elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, Ts, Xs, X, Y, + elapsed, DynConf, Engine, Image, Mask, Ops, Options, Polygon, Processor, SamPrompt, Ts, Xs, X, + Y, }; #[derive(Debug, Clone)] pub enum SamKind { Sam, - Sam2, + Sam2, // 2.0 MobileSam, SamHq, EdgeSam, @@ -31,54 +32,6 @@ impl TryFrom<&str> for SamKind { } } -#[derive(Debug, Default, Clone)] -pub struct SamPrompt { - points: Vec, - labels: Vec, -} - -impl SamPrompt { - pub fn everything() -> Self { - todo!() - } - - pub fn with_postive_point(mut self, x: f32, y: f32) -> Self { - self.points.extend_from_slice(&[x, y]); - self.labels.push(1.); - self - } - - pub fn with_negative_point(mut self, x: f32, y: f32) -> Self { - self.points.extend_from_slice(&[x, y]); - self.labels.push(0.); - self - } - - pub fn with_bbox(mut self, x: f32, y: f32, x2: f32, y2: f32) -> Self { - self.points.extend_from_slice(&[x, y, x2, y2]); - self.labels.extend_from_slice(&[2., 3.]); - self - } - - pub fn point_coords(&self, r: f32) -> Result { - let point_coords = Array::from_shape_vec((1, self.num_points(), 2), self.points.clone())? - .into_dyn() - .into_owned(); - Ok(X::from(point_coords * r)) - } - - pub fn point_labels(&self) -> Result { - let point_labels = Array::from_shape_vec((1, self.num_points()), self.labels.clone())? - .into_dyn() - .into_owned(); - Ok(X::from(point_labels)) - } - - pub fn num_points(&self) -> usize { - self.points.len() / 2 - } -} - #[derive(Builder, Debug)] pub struct SAM { encoder: Engine, @@ -167,14 +120,28 @@ impl SAM { ); let ratio = self.processor.images_transform_info[idx].height_scale; + let (mut point_coords, mut point_labels) = ( + prompts[idx].point_coords(ratio)?, + prompts[idx].point_labels()?, + ); + + if point_coords.shape()[0] != 1 { + point_coords = X::from(point_coords.slice(s![-1, .., ..]).to_owned().into_dyn()) + .insert_axis(0)?; + } + if point_labels.shape()[0] != 1 { + point_labels = X::from(point_labels.slice(s![-1, ..,]).to_owned().into_dyn()) + .insert_axis(0)?; + } + let args = match self.kind { SamKind::Sam | SamKind::MobileSam => { vec![ X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? .repeat(0, self.batch)?, // image_embedding - prompts[idx].point_coords(ratio)?, // point_coords - prompts[idx].point_labels()?, // point_labels + point_coords, + point_labels, X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input, X::zeros(&[1]), // has_mask_input X::from(vec![image_height as _, image_width as _]), // orig_im_size @@ -189,8 +156,8 @@ impl SAM { .insert_axis(0)? .insert_axis(0)? .repeat(0, self.batch)?, // intern_embedding - prompts[idx].point_coords(ratio)?, // point_coords - prompts[idx].point_labels()?, // point_labels + point_coords, + point_labels, X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input X::zeros(&[1]), // has_mask_input X::from(vec![image_height as _, image_width as _]), // orig_im_size @@ -201,8 +168,8 @@ impl SAM { X::from(image_embedding.into_dyn().into_owned()) .insert_axis(0)? .repeat(0, self.batch)?, - prompts[idx].point_coords(ratio)?, - prompts[idx].point_labels()?, + point_coords, + point_labels, ] } SamKind::Sam2 => { @@ -228,11 +195,11 @@ impl SAM { ) .insert_axis(0)? .repeat(0, self.batch)?, - prompts[idx].point_coords(ratio)?, - prompts[idx].point_labels()?, - X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input - X::zeros(&[1]), // has_mask_input - X::from(vec![image_height as _, image_width as _]), // orig_im_size + point_coords, + point_labels, + X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), + X::zeros(&[1]), + X::from(vec![image_height as _, image_width as _]), ] } }; diff --git a/src/models/sam/mod.rs b/src/models/sam/mod.rs index fbd2b75..bce941d 100644 --- a/src/models/sam/mod.rs +++ b/src/models/sam/mod.rs @@ -2,3 +2,88 @@ mod config; mod r#impl; pub use r#impl::*; + +#[derive(Debug, Default, Clone)] +pub struct SamPrompt { + pub coords: Vec>, + pub labels: Vec>, +} + +impl SamPrompt { + pub fn point_coords(&self, ratio: f32) -> anyhow::Result { + // [num_labels,num_points,2] + let num_labels = self.coords.len(); + let num_points = if num_labels > 0 { + self.coords[0].len() + } else { + 0 + }; + let flat: Vec = self + .coords + .iter() + .flat_map(|v| v.iter().flat_map(|&[x, y]| [x, y])) + .collect(); + let y = ndarray::Array3::from_shape_vec((num_labels, num_points, 2), flat)?.into_dyn(); + + Ok((y * ratio).into()) + } + + pub fn point_labels(&self) -> anyhow::Result { + // [num_labels,num_points] + let num_labels = self.labels.len(); + let num_points = if num_labels > 0 { + self.labels[0].len() + } else { + 0 + }; + let flat: Vec = self.labels.iter().flat_map(|v| v.iter().copied()).collect(); + let y = ndarray::Array2::from_shape_vec((num_labels, num_points), flat)?.into_dyn(); + Ok(y.into()) + } + + pub fn with_xyxy(mut self, x1: f32, y1: f32, x2: f32, y2: f32) -> Self { + // TODO: if already has points, push_front coords + self.coords.push(vec![[x1, y1], [x2, y2]]); + self.labels.push(vec![2., 3.]); + + self + } + + pub fn with_positive_point(mut self, x: f32, y: f32) -> Self { + self = self.add_point(x, y, 1.); + self + } + + pub fn with_negative_point(mut self, x: f32, y: f32) -> Self { + self = self.add_point(x, y, 0.); + self + } + + fn add_point(mut self, x: f32, y: f32, id: f32) -> Self { + if self.coords.is_empty() { + self.coords.push(vec![[x, y]]); + self.labels.push(vec![id]); + } else { + if let Some(last) = self.coords.last_mut() { + last.extend_from_slice(&[[x, y]]); + } + + if let Some(last) = self.labels.last_mut() { + last.extend_from_slice(&[id]); + } + } + self + } + + pub fn with_positive_point_object(mut self, x: f32, y: f32) -> Self { + self.coords.push(vec![[x, y]]); + self.labels.push(vec![1.]); + self + } + + pub fn with_negative_point_object(mut self, x: f32, y: f32) -> Self { + self.coords.push(vec![[x, y]]); + self.labels.push(vec![0.]); + self + } +} diff --git a/src/models/sam2/README.md b/src/models/sam2/README.md new file mode 100644 index 0000000..e3e9a5f --- /dev/null +++ b/src/models/sam2/README.md @@ -0,0 +1,10 @@ +# Segment Anything Model + +## Official Repository + +The official repository can be found on [sam2](https://github.com/facebookresearch/sam2) + + +## Example + +Refer to the [example](../../../examples/sam2) diff --git a/src/models/sam2/config.rs b/src/models/sam2/config.rs new file mode 100644 index 0000000..db9df28 --- /dev/null +++ b/src/models/sam2/config.rs @@ -0,0 +1,50 @@ +use crate::Options; + +/// Model configuration for `SAM2.1` +impl Options { + pub fn sam2_encoder() -> Self { + Self::sam() + .with_model_ixx(0, 2, 1024.into()) + .with_model_ixx(0, 3, 1024.into()) + .with_resize_mode(crate::ResizeMode::FitAdaptive) + .with_resize_filter("Bilinear") + .with_image_mean(&[0.485, 0.456, 0.406]) + .with_image_std(&[0.229, 0.224, 0.225]) + } + + pub fn sam2_decoder() -> Self { + Self::sam() + } + + pub fn sam2_1_tiny_encoder() -> Self { + Self::sam2_encoder().with_model_file("sam2.1-hiera-tiny-encoder.onnx") + } + + pub fn sam2_1_tiny_decoder() -> Self { + Self::sam2_decoder().with_model_file("sam2.1-hiera-tiny-decoder.onnx") + } + + pub fn sam2_1_small_encoder() -> Self { + Self::sam2_encoder().with_model_file("sam2.1-hiera-small-encoder.onnx") + } + + pub fn sam2_1_small_decoder() -> Self { + Self::sam2_decoder().with_model_file("sam2.1-hiera-small-decoder.onnx") + } + + pub fn sam2_1_base_plus_encoder() -> Self { + Self::sam2_encoder().with_model_file("sam2.1-hiera-base-plus-encoder.onnx") + } + + pub fn sam2_1_base_plus_decoder() -> Self { + Self::sam2_decoder().with_model_file("sam2.1-hiera-base-plus-decoder.onnx") + } + + pub fn sam2_1_large_encoder() -> Self { + Self::sam2_encoder().with_model_file("sam2.1-hiera-large-encoder.onnx") + } + + pub fn sam2_1_large_decoder() -> Self { + Self::sam2_decoder().with_model_file("sam2.1-hiera-large-decoder.onnx") + } +} diff --git a/src/models/sam2/impl.rs b/src/models/sam2/impl.rs new file mode 100644 index 0000000..cb32dce --- /dev/null +++ b/src/models/sam2/impl.rs @@ -0,0 +1,164 @@ +use aksr::Builder; +use anyhow::Result; +use ndarray::{s, Axis}; + +use crate::{ + elapsed, DynConf, Engine, Image, Mask, Ops, Options, Processor, SamPrompt, Ts, Xs, X, Y, +}; + +#[derive(Builder, Debug)] +pub struct SAM2 { + encoder: Engine, + decoder: Engine, + height: usize, + width: usize, + batch: usize, + processor: Processor, + conf: DynConf, + ts: Ts, + spec: String, +} + +impl SAM2 { + pub fn new(options_encoder: Options, options_decoder: Options) -> Result { + let encoder = options_encoder.to_engine()?; + let decoder = options_decoder.to_engine()?; + let (batch, height, width) = ( + encoder.batch().opt(), + encoder.try_height().unwrap_or(&1024.into()).opt(), + encoder.try_width().unwrap_or(&1024.into()).opt(), + ); + let ts = Ts::merge(&[encoder.ts(), decoder.ts()]); + let spec = encoder.spec().to_owned(); + let processor = options_encoder + .to_processor()? + .with_image_width(width as _) + .with_image_height(height as _); + let conf = DynConf::new(options_encoder.class_confs(), 1); + + Ok(Self { + encoder, + decoder, + conf, + batch, + height, + width, + ts, + processor, + spec, + }) + } + + pub fn forward(&mut self, xs: &[Image], prompts: &[SamPrompt]) -> Result> { + let ys = elapsed!("encode", self.ts, { self.encode(xs)? }); + let ys = elapsed!("decode", self.ts, { self.decode(&ys, prompts)? }); + + Ok(ys) + } + + pub fn encode(&mut self, xs: &[Image]) -> Result { + let xs_ = self.processor.process_images(xs)?; + self.encoder.run(Xs::from(xs_)) + } + + pub fn decode(&mut self, xs: &Xs, prompts: &[SamPrompt]) -> Result> { + let (image_embeddings, high_res_features_0, high_res_features_1) = (&xs[0], &xs[1], &xs[2]); + + let mut ys: Vec = Vec::new(); + for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() { + let (image_height, image_width) = ( + self.processor.images_transform_info[idx].height_src, + self.processor.images_transform_info[idx].width_src, + ); + let ratio = self.processor.images_transform_info[idx].height_scale; + + let ys_ = self.decoder.run(Xs::from(vec![ + X::from(image_embedding.into_dyn().into_owned()) + .insert_axis(0)? + .repeat(0, self.batch)?, + X::from( + high_res_features_0 + .slice(s![idx, .., .., ..]) + .into_dyn() + .into_owned(), + ) + .insert_axis(0)? + .repeat(0, self.batch)?, + X::from( + high_res_features_1 + .slice(s![idx, .., .., ..]) + .into_dyn() + .into_owned(), + ) + .insert_axis(0)? + .repeat(0, self.batch)?, + prompts[idx].point_coords(ratio)?, + prompts[idx].point_labels()?, + // TODO + X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), + X::zeros(&[1]), + X::from(vec![self.width as _, self.height as _]), + ]))?; + + let mut y_masks: Vec = Vec::new(); + + // masks & confs + let (masks, confs) = (&ys_[0], &ys_[1]); + + for (id, (mask, iou)) in masks + .axis_iter(Axis(0)) + .zip(confs.axis_iter(Axis(0))) + .enumerate() + { + let (i, conf) = match iou + .to_owned() + .into_raw_vec_and_offset() + .0 + .into_iter() + .enumerate() + .max_by(|a, b| a.1.total_cmp(&b.1)) + { + Some((i, c)) => (i, c), + None => continue, + }; + + if conf < self.conf[0] { + continue; + } + let mask = mask.slice(s![i, .., ..]); + + let (h, w) = mask.dim(); + let luma = Ops::resize_lumaf32_u8( + &mask.into_owned().into_raw_vec_and_offset().0, + w as _, + h as _, + image_width as _, + image_height as _, + true, + "Bilinear", + )?; + + // contours + let mask = Mask::new(&luma, image_width, image_height)?.with_id(id); + y_masks.push(mask); + } + + let mut y = Y::default(); + if !y_masks.is_empty() { + y = y.with_masks(&y_masks); + } + + ys.push(y); + } + + Ok(ys) + } + + pub fn width_low_res(&self) -> usize { + self.width / 4 + } + + pub fn height_low_res(&self) -> usize { + self.height / 4 + } +} diff --git a/src/models/sam2/mod.rs b/src/models/sam2/mod.rs new file mode 100644 index 0000000..fbd2b75 --- /dev/null +++ b/src/models/sam2/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod r#impl; + +pub use r#impl::*; diff --git a/src/utils/options.rs b/src/utils/options.rs index 3db49b6..613a620 100644 --- a/src/utils/options.rs +++ b/src/utils/options.rs @@ -27,6 +27,14 @@ pub struct Options { pub trt_fp16: bool, pub profile: bool, + // models + pub model_encoder_file: Option, + pub model_decoder_file: Option, + pub visual_encoder_file: Option, + pub visual_decoder_file: Option, + pub textual_encoder_file: Option, + pub textual_decoder_file: Option, + // Processor configs #[args(except(setter))] pub image_width: u32, @@ -113,8 +121,8 @@ pub struct Options { pub binary_thresh: Option, // For SAM - pub sam_kind: Option, - pub low_res_mask: Option, + pub sam_kind: Option, // TODO: remove + pub low_res_mask: Option, // TODO: remove // Others pub ort_graph_opt_level: Option, @@ -203,6 +211,12 @@ impl Default for Options { topk_2: None, topk_3: None, ort_graph_opt_level: None, + model_encoder_file: None, + model_decoder_file: None, + visual_encoder_file: None, + visual_decoder_file: None, + textual_encoder_file: None, + textual_decoder_file: None, } } }