mirror of
https://github.com/mii443/usls.git
synced 2025-12-03 11:08:20 +00:00
Add SAM2 and ONNX (#28)
This commit is contained in:
@@ -38,6 +38,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_model("sam-vit-b-decoder-u8.onnx")?;
|
||||
(options_encoder, options_decoder, "SAM")
|
||||
}
|
||||
SamKind::Sam2 => {
|
||||
let options_encoder = Options::default()
|
||||
// .with_model("sam2-hiera-tiny-encoder.onnx")?;
|
||||
// .with_model("sam2-hiera-small-encoder.onnx")?;
|
||||
.with_model("sam2-hiera-base-plus-encoder.onnx")?;
|
||||
let options_decoder = Options::default()
|
||||
.with_i31((1, 1, 1).into())
|
||||
.with_i41((1, 1, 1).into())
|
||||
.with_sam_kind(SamKind::Sam2)
|
||||
// .with_model("sam2-hiera-tiny-decoder.onnx")?;
|
||||
// .with_model("sam2-hiera-small-decoder.onnx")?;
|
||||
.with_model("sam2-hiera-base-plus-decoder.onnx")?;
|
||||
(options_encoder, options_decoder, "SAM2")
|
||||
}
|
||||
SamKind::MobileSam => {
|
||||
let options_encoder = Options::default().with_model("mobile-sam-vit-t-encoder.onnx")?;
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
|
||||
#[derive(Debug, Clone, clap::ValueEnum)]
|
||||
pub enum SamKind {
|
||||
Sam,
|
||||
Sam2,
|
||||
MobileSam,
|
||||
SamHq,
|
||||
EdgeSam,
|
||||
@@ -94,7 +95,7 @@ impl SAM {
|
||||
SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
|
||||
options_decoder.use_low_res_mask.unwrap_or(false)
|
||||
}
|
||||
SamKind::EdgeSam => true,
|
||||
SamKind::EdgeSam | SamKind::Sam2 => true,
|
||||
};
|
||||
|
||||
encoder.dry_run()?;
|
||||
@@ -142,9 +143,13 @@ impl SAM {
|
||||
xs0: &[DynamicImage],
|
||||
prompts: &[SamPrompt],
|
||||
) -> Result<Vec<Y>> {
|
||||
let mut ys: Vec<Y> = Vec::new();
|
||||
let (image_embeddings, high_res_features_0, high_res_features_1) = match self.kind {
|
||||
SamKind::Sam2 => (&xs[0], Some(&xs[1]), Some(&xs[2])),
|
||||
_ => (&xs[0], None, None),
|
||||
};
|
||||
|
||||
for (idx, image_embedding) in xs[0].axis_iter(Axis(0)).enumerate() {
|
||||
let mut ys: Vec<Y> = Vec::new();
|
||||
for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() {
|
||||
let image_width = xs0[idx].width() as f32;
|
||||
let image_height = xs0[idx].height() as f32;
|
||||
let ratio =
|
||||
@@ -180,6 +185,32 @@ impl SAM {
|
||||
prompts[idx].point_labels()?,
|
||||
]
|
||||
}
|
||||
SamKind::Sam2 => {
|
||||
vec![
|
||||
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
|
||||
X::from(
|
||||
high_res_features_0
|
||||
.unwrap()
|
||||
.slice(s![idx, .., .., ..])
|
||||
.into_dyn()
|
||||
.into_owned(),
|
||||
)
|
||||
.insert_axis(0)?,
|
||||
X::from(
|
||||
high_res_features_1
|
||||
.unwrap()
|
||||
.slice(s![idx, .., .., ..])
|
||||
.into_dyn()
|
||||
.into_owned(),
|
||||
)
|
||||
.insert_axis(0)?,
|
||||
prompts[idx].point_coords(ratio)?,
|
||||
prompts[idx].point_labels()?,
|
||||
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
|
||||
X::zeros(&[1]), // has_mask_input
|
||||
X::from(vec![image_height, image_width]), // orig_im_size
|
||||
]
|
||||
}
|
||||
};
|
||||
|
||||
let ys_ = self.decoder.run(args)?;
|
||||
@@ -196,6 +227,7 @@ impl SAM {
|
||||
(&ys_[2], &ys_[1])
|
||||
}
|
||||
}
|
||||
SamKind::Sam2 => (&ys_[0], &ys_[1]),
|
||||
SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) {
|
||||
(2, 4) => (&ys_[1], &ys_[0]),
|
||||
(4, 2) => (&ys_[0], &ys_[1]),
|
||||
|
||||
Reference in New Issue
Block a user