mirror of
https://github.com/mii443/usls.git
synced 2025-12-03 11:08:20 +00:00
Add SAM2 and ONNX (#28)
This commit is contained in:
@@ -38,6 +38,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
.with_model("sam-vit-b-decoder-u8.onnx")?;
|
.with_model("sam-vit-b-decoder-u8.onnx")?;
|
||||||
(options_encoder, options_decoder, "SAM")
|
(options_encoder, options_decoder, "SAM")
|
||||||
}
|
}
|
||||||
|
SamKind::Sam2 => {
|
||||||
|
let options_encoder = Options::default()
|
||||||
|
// .with_model("sam2-hiera-tiny-encoder.onnx")?;
|
||||||
|
// .with_model("sam2-hiera-small-encoder.onnx")?;
|
||||||
|
.with_model("sam2-hiera-base-plus-encoder.onnx")?;
|
||||||
|
let options_decoder = Options::default()
|
||||||
|
.with_i31((1, 1, 1).into())
|
||||||
|
.with_i41((1, 1, 1).into())
|
||||||
|
.with_sam_kind(SamKind::Sam2)
|
||||||
|
// .with_model("sam2-hiera-tiny-decoder.onnx")?;
|
||||||
|
// .with_model("sam2-hiera-small-decoder.onnx")?;
|
||||||
|
.with_model("sam2-hiera-base-plus-decoder.onnx")?;
|
||||||
|
(options_encoder, options_decoder, "SAM2")
|
||||||
|
}
|
||||||
SamKind::MobileSam => {
|
SamKind::MobileSam => {
|
||||||
let options_encoder = Options::default().with_model("mobile-sam-vit-t-encoder.onnx")?;
|
let options_encoder = Options::default().with_model("mobile-sam-vit-t-encoder.onnx")?;
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
|
|||||||
#[derive(Debug, Clone, clap::ValueEnum)]
|
#[derive(Debug, Clone, clap::ValueEnum)]
|
||||||
pub enum SamKind {
|
pub enum SamKind {
|
||||||
Sam,
|
Sam,
|
||||||
|
Sam2,
|
||||||
MobileSam,
|
MobileSam,
|
||||||
SamHq,
|
SamHq,
|
||||||
EdgeSam,
|
EdgeSam,
|
||||||
@@ -94,7 +95,7 @@ impl SAM {
|
|||||||
SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
|
SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
|
||||||
options_decoder.use_low_res_mask.unwrap_or(false)
|
options_decoder.use_low_res_mask.unwrap_or(false)
|
||||||
}
|
}
|
||||||
SamKind::EdgeSam => true,
|
SamKind::EdgeSam | SamKind::Sam2 => true,
|
||||||
};
|
};
|
||||||
|
|
||||||
encoder.dry_run()?;
|
encoder.dry_run()?;
|
||||||
@@ -142,9 +143,13 @@ impl SAM {
|
|||||||
xs0: &[DynamicImage],
|
xs0: &[DynamicImage],
|
||||||
prompts: &[SamPrompt],
|
prompts: &[SamPrompt],
|
||||||
) -> Result<Vec<Y>> {
|
) -> Result<Vec<Y>> {
|
||||||
let mut ys: Vec<Y> = Vec::new();
|
let (image_embeddings, high_res_features_0, high_res_features_1) = match self.kind {
|
||||||
|
SamKind::Sam2 => (&xs[0], Some(&xs[1]), Some(&xs[2])),
|
||||||
|
_ => (&xs[0], None, None),
|
||||||
|
};
|
||||||
|
|
||||||
for (idx, image_embedding) in xs[0].axis_iter(Axis(0)).enumerate() {
|
let mut ys: Vec<Y> = Vec::new();
|
||||||
|
for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() {
|
||||||
let image_width = xs0[idx].width() as f32;
|
let image_width = xs0[idx].width() as f32;
|
||||||
let image_height = xs0[idx].height() as f32;
|
let image_height = xs0[idx].height() as f32;
|
||||||
let ratio =
|
let ratio =
|
||||||
@@ -180,6 +185,32 @@ impl SAM {
|
|||||||
prompts[idx].point_labels()?,
|
prompts[idx].point_labels()?,
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
SamKind::Sam2 => {
|
||||||
|
vec![
|
||||||
|
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
|
||||||
|
X::from(
|
||||||
|
high_res_features_0
|
||||||
|
.unwrap()
|
||||||
|
.slice(s![idx, .., .., ..])
|
||||||
|
.into_dyn()
|
||||||
|
.into_owned(),
|
||||||
|
)
|
||||||
|
.insert_axis(0)?,
|
||||||
|
X::from(
|
||||||
|
high_res_features_1
|
||||||
|
.unwrap()
|
||||||
|
.slice(s![idx, .., .., ..])
|
||||||
|
.into_dyn()
|
||||||
|
.into_owned(),
|
||||||
|
)
|
||||||
|
.insert_axis(0)?,
|
||||||
|
prompts[idx].point_coords(ratio)?,
|
||||||
|
prompts[idx].point_labels()?,
|
||||||
|
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
|
||||||
|
X::zeros(&[1]), // has_mask_input
|
||||||
|
X::from(vec![image_height, image_width]), // orig_im_size
|
||||||
|
]
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let ys_ = self.decoder.run(args)?;
|
let ys_ = self.decoder.run(args)?;
|
||||||
@@ -196,6 +227,7 @@ impl SAM {
|
|||||||
(&ys_[2], &ys_[1])
|
(&ys_[2], &ys_[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
SamKind::Sam2 => (&ys_[0], &ys_[1]),
|
||||||
SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) {
|
SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) {
|
||||||
(2, 4) => (&ys_[1], &ys_[0]),
|
(2, 4) => (&ys_[1], &ys_[0]),
|
||||||
(4, 2) => (&ys_[0], &ys_[1]),
|
(4, 2) => (&ys_[0], &ys_[1]),
|
||||||
|
|||||||
Reference in New Issue
Block a user