From 53d14ee2fb686ff7c95ed7e51eda79b700935b2f Mon Sep 17 00:00:00 2001
From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com>
Date: Sat, 3 Aug 2024 18:03:35 +0800
Subject: [PATCH] Add `Xs`, a wrapper over `Vec<X>` (#29)

---
 Cargo.toml                            |   2 +-
 src/core/mod.rs                       |   6 +-
 src/core/{engine.rs => ort_engine.rs} |  12 +--
 src/core/vision.rs                    |   8 +-
 src/core/x.rs                         |  16 ++-
 src/core/xs.rs                        | 113 +++++++++++++++++++++
 src/models/blip.rs                    |  10 +-
 src/models/clip.rs                    |   6 +-
 src/models/db.rs                      |   6 +-
 src/models/depth_anything.rs          |   6 +-
 src/models/dinov2.rs                  |   4 +-
 src/models/mod.rs                     |   2 -
 src/models/modnet.rs                  |   6 +-
 src/models/rtdetr.rs                  | 140 --------------------------
 src/models/rtmo.rs                    |   6 +-
 src/models/sam.rs                     |  19 ++--
 src/models/svtr.rs                    |   6 +-
 src/models/yolo.rs                    |  10 +-
 src/models/yolop.rs                   |   7 +-
 src/utils/mod.rs                      |   9 ++
 20 files changed, 193 insertions(+), 201 deletions(-)
 rename src/core/{engine.rs => ort_engine.rs} (99%)
 create mode 100644 src/core/xs.rs
 delete mode 100644 src/models/rtdetr.rs

diff --git a/Cargo.toml b/Cargo.toml
index 72c7d0c..892f3f2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "usls"
-version = "0.0.8"
+version = "0.0.9"
 edition = "2021"
 description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
 repository = "https://github.com/jamjamjon/usls"
diff --git a/src/core/mod.rs b/src/core/mod.rs
index 3bc2d7a..0577c6a 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -2,29 +2,31 @@ mod annotator;
 mod dataloader;
 mod device;
 mod dynconf;
-mod engine;
 mod logits_sampler;
 mod metric;
 mod min_opt_max;
 pub mod onnx;
 pub mod ops;
 mod options;
+mod ort_engine;
 mod tokenizer_stream;
 mod ts;
 mod vision;
 mod x;
+mod xs;

 pub use annotator::Annotator;
 pub use dataloader::DataLoader;
 pub use device::Device;
 pub use dynconf::DynConf;
-pub use engine::OrtEngine;
 pub use logits_sampler::LogitsSampler;
 pub use metric::Metric;
 pub use min_opt_max::MinOptMax;
 pub use ops::Ops;
 pub use options::Options;
+pub use ort_engine::OrtEngine;
 pub use tokenizer_stream::TokenizerStream;
 pub use ts::Ts;
 pub use vision::Vision;
 pub use x::X;
+pub use xs::Xs;
diff --git a/src/core/engine.rs b/src/core/ort_engine.rs
similarity index 99%
rename from src/core/engine.rs
rename to src/core/ort_engine.rs
index 4cda4d9..a30c208 100644
--- a/src/core/engine.rs
+++ b/src/core/ort_engine.rs
@@ -8,7 +8,7 @@ use ort::{
 use prost::Message;
 use std::collections::HashSet;

-use crate::{home_dir, onnx, Device, MinOptMax, Ops, Options, Ts, CHECK_MARK, CROSS_MARK, X};
+use crate::{home_dir, onnx, Device, MinOptMax, Ops, Options, Ts, Xs, CHECK_MARK, CROSS_MARK, X};

 /// Ort Tensor Attrs: name, data_type, dims
 #[derive(Debug)]
@@ -288,6 +288,7 @@ impl OrtEngine {
                 let x: Array<f32, IxDyn> = Array::ones(x).into_dyn();
                 xs.push(X::from(x));
             }
+            let xs = Xs::from(xs);
             for _ in 0..self.num_dry_run {
                 // self.run(xs.as_ref())?;
                 self.run(xs.clone())?;
@@ -298,11 +299,11 @@ impl OrtEngine {
         }
         Ok(())
     }
-    pub fn run(&mut self, xs: Vec<X>) -> Result<Vec<X>> {
+    pub fn run(&mut self, xs: Xs) -> Result<Xs> {
         // inputs dtype alignment
         let mut xs_ = Vec::new();
         let t_pre = std::time::Instant::now();
-        for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.iter()) {
+        for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
            let x_ = match &idtype {
                TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
                TensorElementType::Float16 => {
@@ -334,7 +335,7 @@ impl OrtEngine {
         self.ts.add_or_push(1, t_run);

         // oputput
-        let mut ys = Vec::new();
+        let mut ys = Xs::new();
         let t_post = std::time::Instant::now();
         for (dtype, name) in self
             .outputs_attrs
@@ -358,8 +359,7 @@ impl OrtEngine {
                     .into_owned(),
                 _ => todo!(),
             };
-            // ys.push(y_);
-            ys.push(X::from(y_));
+            ys.push_kv(name.as_str(), X::from(y_))?;
         }
         let t_post = t_post.elapsed();
         self.ts.add_or_push(2, t_post);
diff --git a/src/core/vision.rs b/src/core/vision.rs
index b3dad90..f7fe6ed 100644
--- a/src/core/vision.rs
+++ b/src/core/vision.rs
@@ -1,4 +1,4 @@
-use crate::{Options, X, Y};
+use crate::{Options, Xs, Y};

 pub trait Vision: Sized {
     type Input; // DynamicImage
@@ -7,13 +7,13 @@ pub trait Vision: Sized {
     fn new(options: Options) -> anyhow::Result<Self>;

     /// Preprocesses the input data.
-    fn preprocess(&self, xs: &[Self::Input]) -> anyhow::Result<Vec<X>>;
+    fn preprocess(&self, xs: &[Self::Input]) -> anyhow::Result<Xs>;

     /// Executes the model on the preprocessed data.
-    fn inference(&mut self, xs: Vec<X>) -> anyhow::Result<Vec<X>>;
+    fn inference(&mut self, xs: Xs) -> anyhow::Result<Xs>;

     /// Postprocesses the model's output.
-    fn postprocess(&self, xs: Vec<X>, xs0: &[Self::Input]) -> anyhow::Result<Vec<Y>>;
+    fn postprocess(&self, xs: Xs, xs0: &[Self::Input]) -> anyhow::Result<Vec<Y>>;

     /// Executes the full pipeline.
     fn run(&mut self, xs: &[Self::Input]) -> anyhow::Result<Vec<Y>> {
diff --git a/src/core/x.rs b/src/core/x.rs
index e6b39ba..beb026b 100644
--- a/src/core/x.rs
+++ b/src/core/x.rs
@@ -4,7 +4,7 @@ use ndarray::{Array, Dim, IxDyn, IxDynImpl};

 use crate::Ops;

-/// Model input, alias for [`Array<f32, IxDyn>`]
+/// Model input, wrapper over [`Array<f32, IxDyn>`]
 #[derive(Debug, Clone, Default)]
 pub struct X(pub Array<f32, IxDyn>);

@@ -30,7 +30,11 @@ impl std::ops::Deref for X {

 impl X {
     pub fn zeros(shape: &[usize]) -> Self {
-        Self(Array::zeros(Dim(IxDynImpl::from(shape.to_vec()))))
+        Self::from(Array::zeros(Dim(IxDynImpl::from(shape.to_vec()))))
+    }
+
+    pub fn ones(shape: &[usize]) -> Self {
+        Self::from(Array::ones(Dim(IxDynImpl::from(shape.to_vec()))))
     }

     pub fn apply(ops: &[Ops]) -> Result<Self> {
@@ -77,6 +81,10 @@ impl X {
         self.0.shape()
     }
+    pub fn ndim(&self) -> usize {
+        self.0.ndim()
+    }
+
     pub fn normalize(mut self, min_: f32, max_: f32) -> Result<Self> {
         self.0 = Ops::normalize(self.0, min_, max_)?;
         Ok(self)
     }
@@ -93,7 +101,7 @@ impl X {
     }

     pub fn resize(xs: &[DynamicImage], height: u32, width: u32, filter: &str) -> Result<Self> {
-        Ok(Self(Ops::resize(xs, height, width, filter)?))
+        Ok(Self::from(Ops::resize(xs, height, width, filter)?))
     }

     pub fn letterbox(
@@ -105,7 +113,7 @@ impl X {
         resize_by: &str,
         center: bool,
     ) -> Result<Self> {
-        Ok(Self(Ops::letterbox(
+        Ok(Self::from(Ops::letterbox(
             xs, height, width, filter, bg, resize_by, center,
         )?))
     }
 }
diff --git a/src/core/xs.rs b/src/core/xs.rs
new file mode 100644
index 0000000..8b6a11c
--- /dev/null
+++ b/src/core/xs.rs
@@ -0,0 +1,113 @@
+use anyhow::Result;
+use std::collections::HashMap;
+use std::ops::{Deref, Index};
+
+use crate::{string_random, X};
+
+#[derive(Debug, Default, Clone)]
+pub struct Xs {
+    map: HashMap<String, X>,
+    names: Vec<String>,
+}
+
+impl From<X> for Xs {
+    fn from(x: X) -> Self {
+        let mut xs = Self::default();
+        xs.push(x);
+        xs
+    }
+}
+
+impl From<Vec<X>> for Xs {
+    fn from(xs: Vec<X>) -> Self {
+        let mut ys = Self::default();
+        for x in xs {
+            ys.push(x);
+        }
+        ys
+    }
+}
+
+impl Xs {
+    pub fn new() -> Self {
+        Self {
+            ..Default::default()
+        }
+    }
+
+    pub fn push(&mut self, value: X) {
+        loop {
+            let key = string_random(5);
+            if !self.map.contains_key(&key) {
+                self.names.push(key.to_string());
+                self.map.insert(key.to_string(), value);
+                break;
+            }
+        }
+    }
+
+    pub fn push_kv(&mut self, key: &str, value: X) -> Result<()> {
+        if !self.map.contains_key(key) {
+            self.names.push(key.to_string());
+            self.map.insert(key.to_string(), value);
+            Ok(())
+        } else {
+            anyhow::bail!("Xs already contains key: {:?}", key)
+        }
+    }
+
+    pub fn names(&self) -> &Vec<String> {
+        &self.names
+    }
+}
+
+impl Deref for Xs {
+    type Target = HashMap<String, X>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.map
+    }
+}
+
+impl Index<&str> for Xs {
+    type Output = X;
+
+    fn index(&self, index: &str) -> &Self::Output {
+        self.map.get(index).expect("Index was not found in `Xs`")
+    }
+}
+
+impl Index<usize> for Xs {
+    type Output = X;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        self.names
+            .get(index)
+            .and_then(|key| self.map.get(key))
+            .expect("Index was not found in `Xs`")
+    }
+}
+
+pub struct XsIter<'a> {
+    inner: std::vec::IntoIter<&'a X>,
+}
+
+impl<'a> Iterator for XsIter<'a> {
+    type Item = &'a X;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.inner.next()
+    }
+}
+
+impl<'a> IntoIterator for &'a Xs {
+    type Item = &'a X;
+    type IntoIter = XsIter<'a>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        let values: Vec<&X> = self.names.iter().map(|x| &self.map[x]).collect();
+        XsIter {
+            inner: values.into_iter(),
+        }
+    }
+}
diff --git a/src/models/blip.rs b/src/models/blip.rs
index d70280d..15e36d3 100644
--- a/src/models/blip.rs
+++ b/src/models/blip.rs
@@ -4,7 +4,9 @@ use ndarray::{s, Array, Axis, IxDyn};
 use std::io::Write;
 use tokenizers::Tokenizer;

-use crate::{Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, X, Y};
+use crate::{
+    Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, Xs, X, Y,
+};

 #[derive(Debug)]
 pub struct Blip {
@@ -58,7 +60,7 @@ impl Blip {
             ),
             Ops::Nhwc2nchw,
         ])?;
-        let ys = self.visual.run(vec![xs_])?;
+        let ys = self.visual.run(Xs::from(xs_))?;
         Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
     }

@@ -108,12 +110,12 @@ impl Blip {
                 Array::ones(input_ids_nd.shape()).into_dyn();
             let input_ids_attn_mask = X::from(input_ids_attn_mask);

-            let y = self.textual.run(vec![
+            let y = self.textual.run(Xs::from(vec![
                 input_ids_nd,
                 input_ids_attn_mask,
                 X::from(image_embeds.data().to_owned()),
                 X::from(image_embeds_attn_mask.to_owned()),
-            ])?; // N, length, vocab_size
+            ]))?; // N, length, vocab_size
             let y = y[0].slice(s!(0, -1.., ..));
             let logits = y.slice(s!(0, ..)).to_vec();
             let token_id = logits_sampler.decode(&logits)?;
diff --git a/src/models/clip.rs b/src/models/clip.rs
index 73802e1..14a1392 100644
--- a/src/models/clip.rs
+++ b/src/models/clip.rs
@@ -3,7 +3,7 @@ use image::DynamicImage;
 use ndarray::Array2;
 use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};

-use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, X, Y};
+use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};

 #[derive(Debug)]
 pub struct Clip {
@@ -69,7 +69,7 @@ impl Clip {
             ),
             Ops::Nhwc2nchw,
         ])?;
-        let ys = self.visual.run(vec![xs_])?;
+        let ys = self.visual.run(Xs::from(xs_))?;
         Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
     }

@@ -84,7 +84,7 @@ impl Clip {
             .collect();
         let xs = Array2::from_shape_vec((texts.len(), self.context_length), xs)?.into_dyn();
         let xs = X::from(xs);
-        let ys = self.textual.run(vec![xs])?;
+        let ys = self.textual.run(Xs::from(xs))?;
         Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
     }

diff --git a/src/models/db.rs b/src/models/db.rs
index 68351fa..f97684c 100644
--- a/src/models/db.rs
+++ b/src/models/db.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use image::DynamicImage;
 use ndarray::Axis;

-use crate::{DynConf, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
+use crate::{DynConf, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y};

 #[derive(Debug)]
 pub struct DB {
@@ -60,11 +60,11 @@ impl DB {
             Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3),
             Ops::Nhwc2nchw,
         ])?;
-        let ys = self.engine.run(vec![xs_])?;
+        let ys = self.engine.run(Xs::from(xs_))?;
         self.postprocess(ys, xs)
     }

-    pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
+    pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
         let mut ys = Vec::new();
         for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
             let mut y_bbox = Vec::new();
diff --git a/src/models/depth_anything.rs b/src/models/depth_anything.rs
index b6ec008..2695a95 100644
--- a/src/models/depth_anything.rs
+++ b/src/models/depth_anything.rs
@@ -1,4 +1,4 @@
-use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, X, Y};
+use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
 use anyhow::Result;
 use image::DynamicImage;
 use ndarray::Axis;
@@ -41,11 +41,11 @@ impl DepthAnything {
             Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3),
             Ops::Nhwc2nchw,
         ])?;
-        let ys = self.engine.run(vec![xs_])?;
+        let ys = self.engine.run(Xs::from(xs_))?;
         self.postprocess(ys, xs)
     }

-    pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
+    pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
         let mut ys: Vec<Y> = Vec::new();
         for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
             let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
diff --git a/src/models/dinov2.rs b/src/models/dinov2.rs
index 9b7b227..226b224 100644
--- a/src/models/dinov2.rs
+++ b/src/models/dinov2.rs
@@ -1,4 +1,4 @@
-use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, X, Y};
+use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
 use anyhow::Result;
 use image::DynamicImage;
 // use std::path::PathBuf;
@@ -63,7 +63,7 @@ impl Dinov2 {
             ),
             Ops::Nhwc2nchw,
         ])?;
-        let ys = self.engine.run(vec![xs_])?;
+        let ys = self.engine.run(Xs::from(xs_))?;
         Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
     }

diff --git a/src/models/mod.rs b/src/models/mod.rs
index 87f7822..da81975 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -6,7 +6,6 @@ mod db;
 mod depth_anything;
 mod dinov2;
 mod modnet;
-mod rtdetr;
 mod rtmo;
 mod sam;
 mod svtr;
@@ -20,7 +19,6 @@ pub use db::DB;
 pub use depth_anything::DepthAnything;
 pub use dinov2::Dinov2;
 pub use modnet::MODNet;
-pub use rtdetr::RTDETR;
 pub use rtmo::RTMO;
 pub use sam::{SamKind, SamPrompt, SAM};
 pub use svtr::SVTR;
diff --git a/src/models/modnet.rs b/src/models/modnet.rs
index 5987f09..d606f15 100644
--- a/src/models/modnet.rs
+++ b/src/models/modnet.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use image::DynamicImage;
 use ndarray::Axis;

-use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, X, Y};
+use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};

 #[derive(Debug)]
 pub struct MODNet {
@@ -42,11 +42,11 @@ impl MODNet {
             Ops::Nhwc2nchw,
         ])?;

-        let ys = self.engine.run(vec![xs_])?;
+        let ys = self.engine.run(Xs::from(xs_))?;
         self.postprocess(ys, xs)
     }

-    pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
+    pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
         let mut ys: Vec<Y> = Vec::new();
         for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
             let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
diff --git a/src/models/rtdetr.rs b/src/models/rtdetr.rs
deleted file mode 100644
index 5af4600..0000000
--- a/src/models/rtdetr.rs
+++ /dev/null
@@ -1,140 +0,0 @@
-use anyhow::Result;
-use image::DynamicImage;
-use ndarray::{s, Axis};
-use regex::Regex;
-
-use crate::{Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, X, Y};
-
-#[derive(Debug)]
-pub struct RTDETR {
-    engine: OrtEngine,
-    height: MinOptMax,
-    width: MinOptMax,
-    batch: MinOptMax,
-    confs: DynConf,
-    nc: usize,
-    names: Option<Vec<String>>,
-}
-
-impl RTDETR {
-    pub fn new(options: Options) -> Result<Self> {
-        let mut engine = OrtEngine::new(&options)?;
-        let (batch, height, width) = (
-            engine.inputs_minoptmax()[0][0].to_owned(),
-            engine.inputs_minoptmax()[0][2].to_owned(),
-            engine.inputs_minoptmax()[0][3].to_owned(),
-        );
-        let names: Option<_> = match options.names {
-            None => engine.try_fetch("names").map(|names| {
-                let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap();
-                let mut names_ = vec![];
-                for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) {
-                    names_.push(name.to_string());
-                }
-                names_
-            }),
-            Some(names) => Some(names.to_owned()),
-        };
-        let nc = options.nc.unwrap_or(
-            names
-                .as_ref()
-                .expect("Failed to get num_classes, make it explicit with `--nc`")
-                .len(),
-        );
-        let confs = DynConf::new(&options.confs, nc);
-        engine.dry_run()?;
-
-        Ok(Self {
-            engine,
-            confs,
-            nc,
-            height,
-            width,
-            batch,
-            names,
-        })
-    }
-
-    pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
-        let xs_ = X::apply(&[
-            Ops::Letterbox(
-                xs,
-                self.height() as u32,
-                self.width() as u32,
-                "CatmullRom",
-                114,
-                "auto",
-                false,
-            ),
-            Ops::Normalize(0., 255.),
-            Ops::Nhwc2nchw,
-        ])?;
-        let ys = self.engine.run(vec![xs_])?;
-        self.postprocess(ys, xs)
-    }
-
-    pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
-        const CXYWH_OFFSET: usize = 4; // cxcywh
-        let preds = &xs[0];
-
-        let mut ys = Vec::new();
-        for (idx, anchor) in preds.axis_iter(Axis(0)).enumerate() {
-            // [bs, num_query, 4 + nc]
-            let width_original = xs0[idx].width() as f32;
-            let height_original = xs0[idx].height() as f32;
-            let ratio =
-                (self.width() as f32 / width_original).min(self.height() as f32 / height_original);
-
-            // save each result
-            let mut y_bboxes = Vec::new();
-            for pred in anchor.axis_iter(Axis(0)) {
-                let bbox = pred.slice(s![0..CXYWH_OFFSET]);
-                let clss = pred.slice(s![CXYWH_OFFSET..CXYWH_OFFSET + self.nc]);
-
-                // confidence & id
-                let (id, &confidence) = clss
-                    .into_iter()
-                    .enumerate()
-                    .reduce(|max, x| if x.1 > max.1 { x } else { max })
-                    .unwrap();
-
-                // confs filter
-                if confidence < self.confs[id] {
-                    continue;
-                }
-
-                // bbox -> input size scale -> rescale
-                let x = (bbox[0] - bbox[2] / 2.) * self.width() as f32 / ratio;
-                let y = (bbox[1] - bbox[3] / 2.) * self.height() as f32 / ratio;
-                let w = bbox[2] * self.width() as f32 / ratio;
-                let h = bbox[3] * self.height() as f32 / ratio;
-                y_bboxes.push(
-                    Bbox::default()
-                        .with_xywh(
-                            x.max(0.0f32).min(width_original),
-                            y.max(0.0f32).min(height_original),
-                            w,
-                            h,
-                        )
-                        .with_confidence(confidence)
-                        .with_id(id as isize)
-                        .with_name(self.names.as_ref().map(|names| names[id].to_owned())),
-                )
-            }
-            ys.push(Y::default().with_bboxes(&y_bboxes));
-        }
-        Ok(ys)
-    }
-
-    pub fn batch(&self) -> isize {
-        self.batch.opt
-    }
-
-    pub fn width(&self) -> isize {
-        self.width.opt
-    }
-
-    pub fn height(&self) -> isize {
-        self.height.opt
-    }
-}
diff --git a/src/models/rtmo.rs b/src/models/rtmo.rs
index 3c20090..72b2a8f 100644
--- a/src/models/rtmo.rs
+++ b/src/models/rtmo.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use image::DynamicImage;
 use ndarray::Axis;

-use crate::{Bbox, DynConf, Keypoint, MinOptMax, Options, OrtEngine, X, Y};
+use crate::{Bbox, DynConf, Keypoint, MinOptMax, Options, OrtEngine, Xs, X, Y};

 #[derive(Debug)]
 pub struct RTMO {
@@ -49,11 +49,11 @@ impl RTMO {
                 false,
             )?
             .nhwc2nchw()?;
-        let ys = self.engine.run(vec![xs_])?;
+        let ys = self.engine.run(Xs::from(xs_))?;
         self.postprocess(ys, xs)
     }
-    pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
+    pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
         let mut ys: Vec<Y> = Vec::new();
         let (preds_bboxes, preds_kpts) = if xs[0].ndim() == 3 {
             (&xs[0], &xs[1])
         } else {
diff --git a/src/models/sam.rs b/src/models/sam.rs
index c95fb26..3eb7e7b 100644
--- a/src/models/sam.rs
+++ b/src/models/sam.rs
@@ -3,7 +3,7 @@ use image::DynamicImage;
 use ndarray::{s, Array, Axis};
 use rand::prelude::*;

-use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
+use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y};

 #[derive(Debug, Clone, clap::ValueEnum)]
 pub enum SamKind {
@@ -119,7 +119,7 @@ impl SAM {
         self.decode(ys, xs, prompts)
     }

-    pub fn encode(&mut self, xs: &[DynamicImage]) -> Result<Vec<X>> {
+    pub fn encode(&mut self, xs: &[DynamicImage]) -> Result<Xs> {
         let xs_ = X::apply(&[
             Ops::Letterbox(
                 xs,
@@ -134,12 +134,12 @@ impl SAM {
             Ops::Nhwc2nchw,
         ])?;

-        self.encoder.run(vec![xs_])
+        self.encoder.run(Xs::from(xs_))
     }

     pub fn decode(
         &mut self,
-        xs: Vec<X>,
+        xs: Xs,
         xs0: &[DynamicImage],
         prompts: &[SamPrompt],
     ) -> Result<Vec<Y>> {
@@ -213,7 +213,7 @@ impl SAM {
             }
         };

-        let ys_ = self.decoder.run(args)?;
+        let ys_ = self.decoder.run(Xs::from(args))?;

         let mut y_masks: Vec<Mask> = Vec::new();
         let mut y_polygons: Vec<Polygon> = Vec::new();
@@ -223,16 +223,14 @@ impl SAM {
                 SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
                     if !self.use_low_res_mask {
                         (&ys_[0], &ys_[1])
+                        // (&ys_["masks"], &ys_["iou_predictions"])
                     } else {
                         (&ys_[2], &ys_[1])
+                        // (&ys_["low_res_masks"], &ys_["iou_predictions"])
                     }
                 }
                 SamKind::Sam2 => (&ys_[0], &ys_[1]),
-                SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) {
-                    (2, 4) => (&ys_[1], &ys_[0]),
-                    (4, 2) => (&ys_[0], &ys_[1]),
-                    _ => anyhow::bail!("Can not parse the outputs of decoder."),
-                },
+                SamKind::EdgeSam => (&ys_["masks"], &ys_["scores"]),
             };

             for (mask, iou) in masks.axis_iter(Axis(0)).zip(confs.axis_iter(Axis(0))) {
@@ -251,6 +249,7 @@ impl SAM {
                     continue;
                 }

                 let mask = mask.slice(s![i, .., ..]);
+                let (h, w) = mask.dim();
                 let luma = if self.use_low_res_mask {
                     Ops::resize_lumaf32_vec(
diff --git a/src/models/svtr.rs b/src/models/svtr.rs
index cacc131..7af5023 100644
--- a/src/models/svtr.rs
+++ b/src/models/svtr.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use image::DynamicImage;
 use ndarray::Axis;

-use crate::{DynConf, MinOptMax, Ops, Options, OrtEngine, X, Y};
+use crate::{DynConf, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};

 #[derive(Debug)]
 pub struct SVTR {
@@ -57,11 +57,11 @@ impl SVTR {
             Ops::Nhwc2nchw,
         ])?;

-        let ys = self.engine.run(vec![xs_])?;
+        let ys = self.engine.run(Xs::from(xs_))?;
         self.postprocess(ys)
     }

-    pub fn postprocess(&self, xs: Vec<X>) -> Result<Vec<Y>> {
+    pub fn postprocess(&self, xs: Xs) -> Result<Vec<Y>> {
         let mut ys: Vec<Y> = Vec::new();
         for batch in xs[0].axis_iter(Axis(0)) {
             let preds = batch
diff --git a/src/models/yolo.rs b/src/models/yolo.rs
index 56d5554..82f3878 100644
--- a/src/models/yolo.rs
+++ b/src/models/yolo.rs
@@ -6,7 +6,7 @@ use regex::Regex;

 use crate::{
     Bbox, BoxType, DynConf, Keypoint, Mask, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, Prob,
-    Vision, YOLOPreds, YOLOTask, YOLOVersion, X, Y,
+    Vision, Xs, YOLOPreds, YOLOTask, YOLOVersion, X, Y,
 };

 #[derive(Debug)]
@@ -158,7 +158,7 @@ impl Vision for YOLO {
         })
     }

-    fn preprocess(&self, xs: &[Self::Input]) -> Result<Vec<X>> {
+    fn preprocess(&self, xs: &[Self::Input]) -> Result<Xs> {
         let xs_ = match self.task {
             YOLOTask::Classify => {
                 X::resize(xs, self.height() as u32, self.width() as u32, "Bilinear")?
@@ -179,14 +179,14 @@ impl Vision for YOLO {
                 Ops::Nhwc2nchw,
             ])?,
         };
-        Ok(vec![xs_])
+        Ok(Xs::from(xs_))
     }

-    fn inference(&mut self, xs: Vec<X>) -> Result<Vec<X>> {
+    fn inference(&mut self, xs: Xs) -> Result<Xs> {
         self.engine.run(xs)
     }

-    fn postprocess(&self, xs: Vec<X>, xs0: &[Self::Input]) -> Result<Vec<Y>> {
+    fn postprocess(&self, xs: Xs, xs0: &[Self::Input]) -> Result<Vec<Y>> {
         let protos = if xs.len() == 2 { Some(&xs[1]) } else { None };
         let ys: Vec<Y> = xs[0]
             .axis_iter(Axis(0))
diff --git a/src/models/yolop.rs b/src/models/yolop.rs
index 8a1de77..563e9d7 100644
--- a/src/models/yolop.rs
+++ b/src/models/yolop.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use image::DynamicImage;
 use ndarray::{s, Array, Axis, IxDyn};

-use crate::{Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
+use crate::{Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y};

 #[derive(Debug)]
 pub struct YOLOPv2 {
@@ -50,11 +50,12 @@ impl YOLOPv2 {
             Ops::Normalize(0., 255.),
             Ops::Nhwc2nchw,
         ])?;
-        let ys = self.engine.run(vec![xs_])?;
+        let ys = self.engine.run(Xs::from(xs_))?;
         self.postprocess(ys, xs)
     }

-    pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
+    pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
+        // pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
         let mut ys: Vec<Y> = Vec::new();
         let (xs_da, xs_ll, xs_det) = (&xs[0], &xs[1], &xs[2]);
         for (idx, ((x_det, x_ll), x_da)) in xs_det
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
index 8dd3b65..34a3078 100644
--- a/src/utils/mod.rs
+++ b/src/utils/mod.rs
@@ -1,5 +1,6 @@
 use anyhow::{anyhow, Result};
 use indicatif::{ProgressBar, ProgressStyle};
+use rand::{distributions::Alphanumeric, thread_rng, Rng};
 use std::io::{Read, Write};
 use std::path::{Path, PathBuf};

@@ -79,6 +80,14 @@ pub fn download<P: AsRef<Path> + std::fmt::Debug>(
     Ok(())
 }

+pub(crate) fn string_random(n: usize) -> String {
+    thread_rng()
+        .sample_iter(&Alphanumeric)
+        .take(n)
+        .map(char::from)
+        .collect()
+}
+
 pub(crate) fn string_now(delimiter: &str) -> String {
     let t_now = chrono::Local::now();
     let fmt = format!(