Add SAM2 and ONNX (#28)

This commit is contained in:
Jamjamjon
2024-08-01 17:26:06 +08:00
committed by GitHub
parent 451aa8cc7b
commit 46a4456a38
2 changed files with 49 additions and 3 deletions

View File

@@ -38,6 +38,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.with_model("sam-vit-b-decoder-u8.onnx")?; .with_model("sam-vit-b-decoder-u8.onnx")?;
(options_encoder, options_decoder, "SAM") (options_encoder, options_decoder, "SAM")
} }
// SAM2: build the encoder and decoder `Options` and yield them together with
// the label "SAM2" as a 3-tuple, the same shape the sibling arms return.
SamKind::Sam2 => {
// Encoder only needs its model file; the commented-out lines below are
// alternative (tiny / small) encoder files that can be swapped in.
let options_encoder = Options::default()
// .with_model("sam2-hiera-tiny-encoder.onnx")?;
// .with_model("sam2-hiera-small-encoder.onnx")?;
.with_model("sam2-hiera-base-plus-encoder.onnx")?;
// Decoder: tag the options as SAM2 so downstream code can branch on the kind.
// NOTE(review): `with_i31` / `with_i41` presumably pin the (min, opt, max) of
// two decoder input dimensions to 1 — confirm against `Options`.
let options_decoder = Options::default()
.with_i31((1, 1, 1).into())
.with_i41((1, 1, 1).into())
.with_sam_kind(SamKind::Sam2)
// Alternative (tiny / small) decoder files; presumably should be kept in
// step with the encoder variant chosen above — verify.
// .with_model("sam2-hiera-tiny-decoder.onnx")?;
// .with_model("sam2-hiera-small-decoder.onnx")?;
.with_model("sam2-hiera-base-plus-decoder.onnx")?;
(options_encoder, options_decoder, "SAM2")
}
SamKind::MobileSam => { SamKind::MobileSam => {
let options_encoder = Options::default().with_model("mobile-sam-vit-t-encoder.onnx")?; let options_encoder = Options::default().with_model("mobile-sam-vit-t-encoder.onnx")?;

View File

@@ -8,6 +8,7 @@ use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
#[derive(Debug, Clone, clap::ValueEnum)] #[derive(Debug, Clone, clap::ValueEnum)]
pub enum SamKind { pub enum SamKind {
Sam, Sam,
Sam2,
MobileSam, MobileSam,
SamHq, SamHq,
EdgeSam, EdgeSam,
@@ -94,7 +95,7 @@ impl SAM {
SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => { SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
options_decoder.use_low_res_mask.unwrap_or(false) options_decoder.use_low_res_mask.unwrap_or(false)
} }
SamKind::EdgeSam => true, SamKind::EdgeSam | SamKind::Sam2 => true,
}; };
encoder.dry_run()?; encoder.dry_run()?;
@@ -142,9 +143,13 @@ impl SAM {
xs0: &[DynamicImage], xs0: &[DynamicImage],
prompts: &[SamPrompt], prompts: &[SamPrompt],
) -> Result<Vec<Y>> { ) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new(); let (image_embeddings, high_res_features_0, high_res_features_1) = match self.kind {
SamKind::Sam2 => (&xs[0], Some(&xs[1]), Some(&xs[2])),
_ => (&xs[0], None, None),
};
for (idx, image_embedding) in xs[0].axis_iter(Axis(0)).enumerate() { let mut ys: Vec<Y> = Vec::new();
for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() {
let image_width = xs0[idx].width() as f32; let image_width = xs0[idx].width() as f32;
let image_height = xs0[idx].height() as f32; let image_height = xs0[idx].height() as f32;
let ratio = let ratio =
@@ -180,6 +185,32 @@ impl SAM {
prompts[idx].point_labels()?, prompts[idx].point_labels()?,
] ]
} }
// SAM2 decoder inputs, in the positional order handed to `self.decoder.run`:
// image embedding, two high-resolution feature maps, point coordinates,
// point labels, mask input, has-mask flag, original image size.
SamKind::Sam2 => {
vec![
// Embedding of the current image with a new leading axis inserted at 0.
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
// First high-resolution feature map: take entry `idx` along the first
// axis, copy it, then insert a new axis 0.
// The `unwrap` relies on `high_res_features_0` being `Some`, which the
// destructuring of `xs` only guarantees when `self.kind` is `Sam2`;
// assumes this arm is matched on that same kind — TODO confirm.
X::from(
high_res_features_0
.unwrap()
.slice(s![idx, .., .., ..])
.into_dyn()
.into_owned(),
)
.insert_axis(0)?,
// Second high-resolution feature map, handled the same way (same
// `unwrap` assumption as above).
X::from(
high_res_features_1
.unwrap()
.slice(s![idx, .., .., ..])
.into_dyn()
.into_owned(),
)
.insert_axis(0)?,
// Prompt points for this image; `ratio` is passed so the coordinates can
// be scaled — presumably into the model's input resolution, verify.
prompts[idx].point_coords(ratio)?,
prompts[idx].point_labels()?,
// All-zero mask at the low-resolution size, paired with a zero flag:
// presumably tells the decoder that no prior mask is supplied — confirm
// against the exported ONNX decoder's input contract.
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
X::zeros(&[1]), // has_mask_input
// Original image size as [height, width], taken from `xs0[idx]`.
X::from(vec![image_height, image_width]), // orig_im_size
]
}
}; };
let ys_ = self.decoder.run(args)?; let ys_ = self.decoder.run(args)?;
@@ -196,6 +227,7 @@ impl SAM {
(&ys_[2], &ys_[1]) (&ys_[2], &ys_[1])
} }
} }
SamKind::Sam2 => (&ys_[0], &ys_[1]),
SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) { SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) {
(2, 4) => (&ys_[1], &ys_[0]), (2, 4) => (&ys_[1], &ys_[0]),
(4, 2) => (&ys_[0], &ys_[1]), (4, 2) => (&ys_[0], &ys_[1]),