Accelerate model pre-processing and post-processing (#23)

* Add X struct to handle input and preprocessing

*Add Ops struct to manage common operations

* Use SIMD (fast_image_resize) to accelerate model pre-processing and post-processing
This commit is contained in:
Jamjamjon
2024-06-30 15:19:34 +08:00
committed by GitHub
parent 5f6b814090
commit a5141a53be
33 changed files with 822 additions and 528 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "usls"
version = "0.0.3"
version = "0.0.4"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
@ -44,4 +44,4 @@ ab_glyph = "0.2.23"
geo = "0.28.0"
prost = "0.12.4"
human_bytes = "0.4.3"
fast_image_resize = "3.0.4"
fast_image_resize = { git = "https://github.com/jamjamjon/fast_image_resize", branch = "dev" , features = ["image"]}

View File

@ -25,6 +25,8 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp
|<img src='examples/yolop/demo.png' height="180px">| <img src='examples/face-parsing/demo.png' height="180px"> | <img src='examples/db/demo.png' height="180px"> |
- 2024/06/30: **Accelerate model pre-processing and post-processing using SIMD**. YOLOv8-seg post-processing (~120ms => ~20ms), Depth-Anything post-processing (~23ms => ~2ms).
## Supported Models
@ -100,7 +102,7 @@ check **[ort guide](https://ort.pyke.io/setup/linking)**
#### 1. Add `usls` as a dependency to your project's `Cargo.toml`
```shell
cargo add --git https://github.com/jamjamjon/usls
usls = { git = "https://github.com/jamjamjon/usls", rev = "xxx"}
```
#### 2. Set `Options` and build model

View File

@ -1,4 +1,4 @@
use usls::{models::YOLO, Annotator, DataLoader, Options};
use usls::{models::YOLO, Annotator, DataLoader, Options, Vision};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build model

View File

@ -1,4 +1,4 @@
use usls::{models::YOLO, Annotator, DataLoader, Options};
use usls::{models::YOLO, Annotator, DataLoader, Options, Vision};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build model

View File

@ -16,7 +16,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// run
let y = model.run(&x)?;
// // annotate
// annotate
let annotator = Annotator::default()
.with_saveout("RTMO")
.with_skeletons(&coco::SKELETONS_16);

View File

@ -1,4 +1,4 @@
use usls::{models::YOLO, Annotator, DataLoader, Options};
use usls::{models::YOLO, Annotator, DataLoader, Options, Vision};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build model

View File

@ -1,6 +1,6 @@
use usls::{
models::{YOLOVersion, YOLO},
Annotator, DataLoader, Options,
Annotator, DataLoader, Options, Vision,
};
fn main() -> Result<(), Box<dyn std::error::Error>> {

View File

@ -1,6 +1,6 @@
use usls::{
models::{YOLOTask, YOLOVersion, YOLO},
Annotator, DataLoader, Options,
Annotator, DataLoader, Options, Vision,
};
fn main() -> Result<(), Box<dyn std::error::Error>> {

View File

@ -1,4 +1,4 @@
use usls::{models::YOLO, Annotator, DataLoader, Options};
use usls::{models::YOLO, Annotator, DataLoader, Options, Vision};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build model

View File

@ -1,4 +1,4 @@
use usls::{models::YOLO, Annotator, DataLoader, Options};
use usls::{models::YOLO, Annotator, DataLoader, Options, Vision};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build model

View File

@ -1,4 +1,4 @@
use usls::{models::YOLO, Annotator, DataLoader, Options};
use usls::{models::YOLO, Annotator, DataLoader, Options, Vision};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build model

View File

@ -1,4 +1,4 @@
use usls::{models::YOLO, Annotator, DataLoader, Options};
use usls::{models::YOLO, Annotator, DataLoader, Options, Vision};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// 1.build model

View File

@ -1,13 +1,13 @@
use usls::{coco, models::YOLO, Annotator, DataLoader, Options};
use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build model
let options = Options::default()
.with_model("yolov8m-dyn.onnx")?
// .with_model("yolov8m-dyn.onnx")?
// .with_model("yolov8m-dyn-f16.onnx")?
// .with_model("yolov8m-pose-dyn.onnx")?
// .with_model("yolov8m-cls-dyn.onnx")?
// .with_model("yolov8m-seg-dyn.onnx")?
.with_model("yolov8m-seg-dyn.onnx")?
// .with_model("yolov8m-obb-dyn.onnx")?
// .with_model("yolov8m-oiv7-dyn.onnx")?
// .with_trt(0)
@ -37,6 +37,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// run & annotate
for (xs, _paths) in dl {
let ys = model.run(&xs)?;
// let ys = model.forward(&xs, true)?;
annotator.annotate(&xs, &ys);
}

View File

@ -1,6 +1,6 @@
use usls::{
models::{YOLOVersion, YOLO},
Annotator, DataLoader, Options,
Annotator, DataLoader, Options, Vision,
};
fn main() -> Result<(), Box<dyn std::error::Error>> {

View File

@ -92,10 +92,12 @@ impl DataLoader {
}
pub fn try_read<P: AsRef<Path>>(path: P) -> Result<DynamicImage> {
image::io::Reader::open(&path)
let img = image::io::Reader::open(&path)
.map_err(|_| anyhow!("Failed to open image at {:?}", path.as_ref()))?
.decode()
.map_err(|_| anyhow!("Failed to decode image at {:?}", path.as_ref()))
.map_err(|_| anyhow!("Failed to decode image at {:?}", path.as_ref()))?
.into_rgb8();
Ok(DynamicImage::from(img))
}
pub fn with_batch(mut self, x: usize) -> Self {

View File

@ -9,9 +9,7 @@ use ort::{
use prost::Message;
use std::collections::HashSet;
use crate::{
home_dir, onnx, ops::make_divisible, Device, MinOptMax, Options, Ts, CHECK_MARK, CROSS_MARK,
};
use crate::{home_dir, onnx, Device, MinOptMax, Ops, Options, Ts, CHECK_MARK, CROSS_MARK, X};
/// Ort Tensor Attrs: name, data_type, dims
#[derive(Debug)]
@ -57,7 +55,7 @@ impl OrtEngine {
params += param;
// mems
let param = make_divisible(param, byte_alignment);
let param = Ops::make_divisible(param, byte_alignment);
let n = Self::nbytes_from_onnx_dtype_id(tensor_proto.data_type as usize);
let wbmem = param * n;
wbmems += wbmem;
@ -258,17 +256,18 @@ impl OrtEngine {
pub fn dry_run(&mut self) -> Result<()> {
if self.num_dry_run > 0 {
let mut xs: Vec<Array<f32, IxDyn>> = Vec::new();
let mut xs = Vec::new();
for i in self.inputs_minoptmax.iter() {
let mut x: Vec<usize> = Vec::new();
for i_ in i.iter() {
x.push(i_.opt as usize);
}
let x: Array<f32, IxDyn> = Array::ones(x).into_dyn();
xs.push(x);
xs.push(X::from(x));
}
for _ in 0..self.num_dry_run {
self.run(xs.as_ref())?;
// self.run(xs.as_ref())?;
self.run(xs.clone())?;
}
self.ts.clear();
println!("{CHECK_MARK} Dryrun x{}", self.num_dry_run);
@ -276,7 +275,7 @@ impl OrtEngine {
Ok(())
}
pub fn run(&mut self, xs: &[Array<f32, IxDyn>]) -> Result<Vec<Array<f32, IxDyn>>> {
pub fn run(&mut self, xs: Vec<X>) -> Result<Vec<X>> {
// inputs dtype alignment
let mut xs_ = Vec::new();
let t_pre = std::time::Instant::now();
@ -330,7 +329,8 @@ impl OrtEngine {
.into_owned(),
_ => todo!(),
};
ys.push(y_);
// ys.push(y_);
ys.push(X::from(y_));
}
let t_post = t_post.elapsed();
self.ts.add_or_push(2, t_post);

View File

@ -11,6 +11,8 @@ pub mod ops;
mod options;
mod tokenizer_stream;
mod ts;
mod vision;
mod x;
pub use annotator::Annotator;
pub use dataloader::DataLoader;
@ -20,6 +22,9 @@ pub use engine::OrtEngine;
pub use logits_sampler::LogitsSampler;
pub use metric::Metric;
pub use min_opt_max::MinOptMax;
pub use ops::Ops;
pub use options::Options;
pub use tokenizer_stream::TokenizerStream;
pub use ts::Ts;
pub use vision::Vision;
pub use x::X;

View File

@ -1,230 +1,292 @@
use anyhow::Result;
use fast_image_resize as fr;
use image::{DynamicImage, GenericImageView, ImageBuffer};
use fast_image_resize as fir;
use fast_image_resize::{
images::{CroppedImageMut, Image},
pixels::PixelType,
FilterType, ResizeAlg, ResizeOptions, Resizer,
};
use image::{DynamicImage, GenericImageView};
use ndarray::{s, Array, Axis, IxDyn};
use rayon::prelude::*;
pub fn standardize(xs: Array<f32, IxDyn>, mean: &[f32], std: &[f32]) -> Array<f32, IxDyn> {
let mean = Array::from_shape_vec((1, mean.len(), 1, 1), mean.to_vec()).unwrap();
let std = Array::from_shape_vec((1, std.len(), 1, 1), std.to_vec()).unwrap();
(xs - mean) / std
use crate::X;
pub enum Ops<'a> {
Resize(&'a [DynamicImage], u32, u32, &'a str),
Letterbox(&'a [DynamicImage], u32, u32, &'a str, u8, &'a str, bool),
Normalize(f32, f32),
Standardize(&'a [f32], &'a [f32], usize),
Permute(&'a [usize]),
InsertAxis(usize),
Nhwc2nchw,
Nchw2nhwc,
Norm,
}
pub fn normalize(xs: Array<f32, IxDyn>, min_: f32, max_: f32) -> Array<f32, IxDyn> {
(xs - min_) / (max_ - min_)
}
impl Ops<'_> {
pub fn apply(ops: &[Self]) -> Result<X> {
let mut y = X::default();
pub fn norm2(xs: &Array<f32, IxDyn>) -> Array<f32, IxDyn> {
let std_ = xs
.mapv(|x| x * x)
.sum_axis(Axis(1))
.mapv(f32::sqrt)
.insert_axis(Axis(1));
xs / std_
}
pub fn scale_wh(w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
let r = (w1 / w0).min(h1 / h0);
(r, (w0 * r).round(), (h0 * r).round())
}
pub fn build_resizer(ty: &str) -> fr::Resizer {
let ty = match ty {
"box" => fr::FilterType::Box,
"bilinear" => fr::FilterType::Bilinear,
"hamming" => fr::FilterType::Hamming,
"catmullRom" => fr::FilterType::CatmullRom,
"mitchell" => fr::FilterType::Mitchell,
"lanczos3" => fr::FilterType::Lanczos3,
_ => todo!(),
};
fr::Resizer::new(fr::ResizeAlg::Convolution(ty))
}
pub fn resize(
xs: &[DynamicImage],
height: u32,
width: u32,
filter: &str,
) -> Result<Array<f32, IxDyn>> {
let mut ys = Array::ones((xs.len(), 3, height as usize, width as usize)).into_dyn();
let mut resizer = build_resizer(filter);
for (idx, x) in xs.iter().enumerate() {
// src
let src_image = fr::Image::from_vec_u8(
std::num::NonZeroU32::new(x.width()).unwrap(),
std::num::NonZeroU32::new(x.height()).unwrap(),
x.to_rgb8().into_raw(),
fr::PixelType::U8x3,
)
.unwrap();
// dst
let mut dst_image = fr::Image::new(
std::num::NonZeroU32::new(width).unwrap(),
std::num::NonZeroU32::new(height).unwrap(),
src_image.pixel_type(),
);
// resize
resizer
.resize(&src_image.view(), &mut dst_image.view_mut())
.unwrap();
let buffer = dst_image.into_vec();
// to ndarray
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)
.unwrap()
.mapv(|x| x as f32)
.permuted_axes([2, 0, 1]);
let mut data = ys.slice_mut(s![idx, .., .., ..]);
data.assign(&y_);
for op in ops {
y = match op {
Self::Resize(xs, h, w, filter) => X::resize(xs, *h, *w, filter)?,
Self::Letterbox(xs, h, w, filter, bg, resize_by, center) => {
X::letterbox(xs, *h, *w, filter, *bg, resize_by, *center)?
}
Self::Normalize(min_, max_) => y.normalize(*min_, *max_)?,
Self::Standardize(mean, std, d) => y.standardize(mean, std, *d)?,
Self::Permute(shape) => y.permute(shape)?,
Self::InsertAxis(d) => y.insert_axis(*d)?,
Self::Nhwc2nchw => y.nhwc2nchw()?,
Self::Nchw2nhwc => y.nchw2nhwc()?,
_ => todo!(),
}
}
Ok(y)
}
Ok(ys)
}
pub fn letterbox(
xs: &[DynamicImage],
height: u32,
width: u32,
filter: &str,
bg: Option<u8>,
) -> Result<Array<f32, IxDyn>> {
let mut ys = Array::ones((xs.len(), 3, height as usize, width as usize)).into_dyn();
let mut resizer = build_resizer(filter);
for (idx, x) in xs.iter().enumerate() {
let (w0, h0) = x.dimensions();
let (_, w_new, h_new) = scale_wh(w0 as f32, h0 as f32, width as f32, height as f32);
pub fn normalize(x: Array<f32, IxDyn>, min: f32, max: f32) -> Result<Array<f32, IxDyn>> {
if min > max {
anyhow::bail!("Input `min` is greater than `max`");
}
Ok((x - min) / (max - min))
}
// src
let src_image = fr::Image::from_vec_u8(
std::num::NonZeroU32::new(w0).unwrap(),
std::num::NonZeroU32::new(h0).unwrap(),
x.to_rgb8().into_raw(),
fr::PixelType::U8x3,
)
.unwrap();
pub fn standardize(
x: Array<f32, IxDyn>,
mean: &[f32],
std: &[f32],
dim: usize,
) -> Result<Array<f32, IxDyn>> {
if mean.len() != std.len() {
anyhow::bail!("The lengths of mean and std are not equal.");
}
let shape = x.shape();
if dim >= shape.len() || shape[dim] != mean.len() {
anyhow::bail!("The specified dimension or mean/std length is inconsistent with the input dimensions.");
}
let mut shape = vec![1; shape.len()];
shape[dim] = mean.len();
let mean = Array::from_shape_vec(shape.clone(), mean.to_vec())?;
let std = Array::from_shape_vec(shape, std.to_vec())?;
Ok((x - mean) / std)
}
// dst
let mut dst_image = match bg {
Some(bg) => fr::Image::from_vec_u8(
std::num::NonZeroU32::new(width).unwrap(),
std::num::NonZeroU32::new(height).unwrap(),
vec![bg; 3 * height as usize * width as usize],
src_image.pixel_type(),
)
.unwrap(),
None => fr::Image::new(
std::num::NonZeroU32::new(width).unwrap(),
std::num::NonZeroU32::new(height).unwrap(),
src_image.pixel_type(),
),
pub fn permute(x: Array<f32, IxDyn>, shape: &[usize]) -> Result<Array<f32, IxDyn>> {
if shape.len() != x.shape().len() {
anyhow::bail!(
"Shape inconsistent. Target: {:?}, {}, got: {:?}, {}",
x.shape(),
x.shape().len(),
shape,
shape.len()
);
}
Ok(x.permuted_axes(shape.to_vec()).into_dyn())
}
pub fn nhwc2nchw(x: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>> {
Self::permute(x, &[0, 3, 1, 2])
}
pub fn nchw2nhwc(x: Array<f32, IxDyn>) -> Result<Array<f32, IxDyn>> {
Self::permute(x, &[0, 2, 3, 1])
}
pub fn insert_axis(x: Array<f32, IxDyn>, d: usize) -> Result<Array<f32, IxDyn>> {
if x.shape().len() < d {
anyhow::bail!(
"The specified axis insertion position {} exceeds the shape's maximum limit of {}.",
d,
x.shape().len()
);
}
Ok(x.insert_axis(Axis(d)))
}
pub fn norm(xs: Array<f32, IxDyn>, d: usize) -> Result<Array<f32, IxDyn>> {
if xs.shape().len() < d {
anyhow::bail!(
"The specified axis {} exceeds the shape's maximum limit of {}.",
d,
xs.shape().len()
);
}
let std_ = xs
.mapv(|x| x * x)
.sum_axis(Axis(d))
.mapv(f32::sqrt)
.insert_axis(Axis(d));
Ok(xs / std_)
}
pub fn scale_wh(w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
let r = (w1 / w0).min(h1 / h0);
(r, (w0 * r).round(), (h0 * r).round())
}
pub fn make_divisible(x: usize, divisor: usize) -> usize {
(x + divisor - 1) / divisor * divisor
}
// deprecated
pub fn descale_mask(mask: DynamicImage, w0: f32, h0: f32, w1: f32, h1: f32) -> DynamicImage {
// 0 -> 1
let (_, w, h) = Ops::scale_wh(w1, h1, w0, h0);
let mut mask = mask.to_owned();
let mask = mask.crop(0, 0, w as u32, h as u32);
mask.resize_exact(w1 as u32, h1 as u32, image::imageops::FilterType::Triangle)
}
pub fn resize_lumaf32_vec(
v: &[f32],
w0: f32,
h0: f32,
w1: f32,
h1: f32,
crop_src: bool,
filter: &str,
) -> Result<Vec<u8>> {
let src_mask = fir::images::Image::from_vec_u8(
w0 as _,
h0 as _,
v.iter().flat_map(|x| x.to_le_bytes()).collect(),
fir::PixelType::F32,
)?;
let mut dst_mask = fir::images::Image::new(w1 as _, h1 as _, src_mask.pixel_type());
let (mut resizer, mut options) = Self::build_resizer_filter(filter)?;
if crop_src {
let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _);
options = options.crop(0., 0., w.into(), h.into());
};
resizer.resize(&src_mask, &mut dst_mask, &options)?;
// mutable view
let mut dst_view = dst_image
.view_mut()
.crop(
0,
0,
std::num::NonZeroU32::new(w_new as u32).unwrap(),
std::num::NonZeroU32::new(h_new as u32).unwrap(),
)
.unwrap();
// u8*2 -> f32
let mask_f32: Vec<f32> = dst_mask
.into_vec()
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
// resize
resizer.resize(&src_image.view(), &mut dst_view).unwrap();
let buffer = dst_image.into_vec();
// to ndarray
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)
.unwrap()
.mapv(|x| x as f32)
.permuted_axes([2, 0, 1]);
let mut data = ys.slice_mut(s![idx, .., .., ..]);
data.assign(&y_);
// f32 -> u8
let v: Vec<u8> = mask_f32.par_iter().map(|&x| (x * 255.0) as u8).collect();
Ok(v)
}
Ok(ys)
}
pub fn resize_with_fixed_height(
xs: &[DynamicImage],
height: u32,
width: u32,
filter: &str,
bg: Option<u8>,
) -> Result<Array<f32, IxDyn>> {
let mut ys = Array::ones((xs.len(), 3, height as usize, width as usize)).into_dyn();
let mut resizer = build_resizer(filter);
for (idx, x) in xs.iter().enumerate() {
let (w0, h0) = x.dimensions();
let h_new = height;
let w_new = height * w0 / h0;
// src
let src_image = fr::Image::from_vec_u8(
std::num::NonZeroU32::new(w0).unwrap(),
std::num::NonZeroU32::new(h0).unwrap(),
x.to_rgb8().into_raw(),
fr::PixelType::U8x3,
)
.unwrap();
// dst
let mut dst_image = match bg {
Some(bg) => fr::Image::from_vec_u8(
std::num::NonZeroU32::new(width).unwrap(),
std::num::NonZeroU32::new(height).unwrap(),
vec![bg; 3 * height as usize * width as usize],
src_image.pixel_type(),
)
.unwrap(),
None => fr::Image::new(
std::num::NonZeroU32::new(width).unwrap(),
std::num::NonZeroU32::new(height).unwrap(),
src_image.pixel_type(),
),
pub fn resize_luma8_vec(
v: &[u8],
w0: f32,
h0: f32,
w1: f32,
h1: f32,
crop_src: bool,
filter: &str,
) -> Result<Vec<u8>> {
let src_mask =
fir::images::Image::from_vec_u8(w0 as _, h0 as _, v.to_vec(), fir::PixelType::U8)?;
let mut dst_mask = fir::images::Image::new(w1 as _, h1 as _, src_mask.pixel_type());
let (mut resizer, mut options) = Self::build_resizer_filter(filter)?;
if crop_src {
let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _);
options = options.crop(0., 0., w.into(), h.into());
};
// mutable view
let mut dst_view = dst_image
.view_mut()
.crop(
0,
0,
std::num::NonZeroU32::new(w_new).unwrap(),
std::num::NonZeroU32::new(h_new).unwrap(),
)
.unwrap();
// resize
resizer.resize(&src_image.view(), &mut dst_view).unwrap();
let buffer = dst_image.into_vec();
// to ndarray
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)
.unwrap()
.mapv(|x| x as f32)
.permuted_axes([2, 0, 1]);
let mut data = ys.slice_mut(s![idx, .., .., ..]);
data.assign(&y_);
resizer.resize(&src_mask, &mut dst_mask, &options)?;
Ok(dst_mask.into_vec())
}
Ok(ys)
}
pub fn build_dyn_image_from_raw(v: Vec<f32>, height: u32, width: u32) -> DynamicImage {
let v: ImageBuffer<image::Luma<_>, Vec<f32>> =
ImageBuffer::from_raw(width, height, v).expect("Faild to create image from ndarray");
image::DynamicImage::from(v)
}
pub fn build_resizer_filter(ty: &str) -> Result<(Resizer, ResizeOptions)> {
let ty = match ty {
"Box" => FilterType::Box,
"Bilinear" => FilterType::Bilinear,
"Hamming" => FilterType::Hamming,
"CatmullRom" => FilterType::CatmullRom,
"Mitchell" => FilterType::Mitchell,
"Gaussian" => FilterType::Gaussian,
"Lanczos3" => FilterType::Lanczos3,
_ => anyhow::bail!("Unsupported resize filter type: {ty}"),
};
Ok((
Resizer::new(),
ResizeOptions::new().resize_alg(ResizeAlg::Convolution(ty)),
))
}
pub fn descale_mask(mask: DynamicImage, w0: f32, h0: f32, w1: f32, h1: f32) -> DynamicImage {
// 0 -> 1
let (_, w, h) = scale_wh(w1, h1, w0, h0);
let mut mask = mask.to_owned();
let mask = mask.crop(0, 0, w as u32, h as u32);
mask.resize_exact(w1 as u32, h1 as u32, image::imageops::FilterType::Triangle)
}
pub fn resize(
xs: &[DynamicImage],
height: u32,
width: u32,
filter: &str,
) -> Result<Array<f32, IxDyn>> {
let mut ys = Array::ones((xs.len(), height as usize, width as usize, 3)).into_dyn();
let (mut resizer, options) = Self::build_resizer_filter(filter)?;
for (idx, x) in xs.iter().enumerate() {
let buffer = if x.dimensions() == (width, height) {
x.to_rgba8().into_raw()
} else {
let mut dst_image = Image::new(width, height, PixelType::U8x3);
resizer.resize(x, &mut dst_image, &options)?;
dst_image.into_vec()
};
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)?
.mapv(|x| x as f32);
ys.slice_mut(s![idx, .., .., ..]).assign(&y_);
}
Ok(ys)
}
pub fn make_divisible(x: usize, divisor: usize) -> usize {
(x + divisor - 1) / divisor * divisor
pub fn letterbox(
xs: &[DynamicImage],
height: u32,
width: u32,
filter: &str,
bg: u8,
resize_by: &str,
center: bool,
) -> Result<Array<f32, IxDyn>> {
let mut ys = Array::ones((xs.len(), height as usize, width as usize, 3)).into_dyn();
let (mut resizer, options) = Self::build_resizer_filter(filter)?;
for (idx, x) in xs.iter().enumerate() {
let (w0, h0) = x.dimensions();
let buffer = if w0 == width && h0 == height {
x.to_rgba8().into_raw()
} else {
let (w, h) = match resize_by {
"auto" => {
let r = (width as f32 / w0 as f32).min(height as f32 / h0 as f32);
(
(w0 as f32 * r).round() as u32,
(h0 as f32 * r).round() as u32,
)
}
"height" => (height * w0 / h0, height),
"width" => (width, width * h0 / w0),
_ => anyhow::bail!("Option: width, height, auto"),
};
let mut dst_image = Image::from_vec_u8(
width,
height,
vec![bg; 3 * height as usize * width as usize],
PixelType::U8x3,
)?;
let (l, t) = if center {
if w == width {
(0, (height - h) / 2)
} else {
((width - w) / 2, 0)
}
} else {
(0, 0)
};
let mut cropped_dst_image = CroppedImageMut::new(&mut dst_image, l, t, w, h)?;
resizer.resize(x, &mut cropped_dst_image, &options)?;
dst_image.into_vec()
};
let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)?
.mapv(|x| x as f32);
ys.slice_mut(s![idx, .., .., ..]).assign(&y_);
}
Ok(ys)
}
}

View File

@ -30,7 +30,6 @@ impl Ts {
&self.ts
}
// TODO: overhead?
pub fn add_or_push(&mut self, i: usize, x: Duration) {
match self.ts.get_mut(i) {
Some(elem) => *elem += x,

46
src/core/vision.rs Normal file
View File

@ -0,0 +1,46 @@
use crate::{Options, X, Y};
pub trait Vision: Sized {
type Input; // DynamicImage
/// Creates a new instance of the model with the given options.
fn new(options: Options) -> anyhow::Result<Self>;
/// Preprocesses the input data.
fn preprocess(&self, xs: &[Self::Input]) -> anyhow::Result<Vec<X>>;
/// Executes the model on the preprocessed data.
fn inference(&mut self, xs: Vec<X>) -> anyhow::Result<Vec<X>>;
/// Postprocesses the model's output.
fn postprocess(&self, xs: Vec<X>, xs0: &[Self::Input]) -> anyhow::Result<Vec<Y>>;
/// Executes the full pipeline.
fn run(&mut self, xs: &[Self::Input]) -> anyhow::Result<Vec<Y>> {
let ys = self.preprocess(xs)?;
let ys = self.inference(ys)?;
let ys = self.postprocess(ys, xs)?;
Ok(ys)
}
/// Executes the full pipeline.
fn forward(&mut self, xs: &[Self::Input], profile: bool) -> anyhow::Result<Vec<Y>> {
let t_pre = std::time::Instant::now();
let ys = self.preprocess(xs)?;
let t_pre = t_pre.elapsed();
let t_exe = std::time::Instant::now();
let ys = self.inference(ys)?;
let t_exe = t_exe.elapsed();
let t_post = std::time::Instant::now();
let ys = self.postprocess(ys, xs)?;
let t_post = t_post.elapsed();
if profile {
println!("> Preprocess: {t_pre:?} | Execution: {t_exe:?} | Postprocess: {t_post:?}");
}
Ok(ys)
}
}

89
src/core/x.rs Normal file
View File

@ -0,0 +1,89 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Dim, IxDyn, IxDynImpl};
use crate::Ops;
#[derive(Debug, Clone, Default)]
pub struct X(pub Array<f32, IxDyn>);
impl From<Array<f32, IxDyn>> for X {
fn from(x: Array<f32, IxDyn>) -> Self {
Self(x)
}
}
impl std::ops::Deref for X {
type Target = Array<f32, IxDyn>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl X {
pub fn zeros(shape: &[usize]) -> Self {
Self(Array::zeros(Dim(IxDynImpl::from(shape.to_vec()))))
}
pub fn apply(ops: &[Ops]) -> Result<Self> {
Ops::apply(ops)
}
pub fn permute(mut self, shape: &[usize]) -> Result<Self> {
self.0 = Ops::permute(self.0, shape)?;
Ok(self)
}
pub fn nhwc2nchw(mut self) -> Result<Self> {
self.0 = Ops::nhwc2nchw(self.0)?;
Ok(self)
}
pub fn nchw2nhwc(mut self) -> Result<Self> {
self.0 = Ops::nchw2nhwc(self.0)?;
Ok(self)
}
pub fn insert_axis(mut self, d: usize) -> Result<Self> {
self.0 = Ops::insert_axis(self.0, d)?;
Ok(self)
}
pub fn dims(&self) -> &[usize] {
self.0.shape()
}
pub fn normalize(mut self, min_: f32, max_: f32) -> Result<Self> {
self.0 = Ops::normalize(self.0, min_, max_)?;
Ok(self)
}
pub fn standardize(mut self, mean: &[f32], std: &[f32], dim: usize) -> Result<Self> {
self.0 = Ops::standardize(self.0, mean, std, dim)?;
Ok(self)
}
pub fn norm(mut self, d: usize) -> Result<Self> {
self.0 = Ops::norm(self.0, d)?;
Ok(self)
}
pub fn resize(xs: &[DynamicImage], height: u32, width: u32, filter: &str) -> Result<Self> {
Ok(Self(Ops::resize(xs, height, width, filter)?))
}
pub fn letterbox(
xs: &[DynamicImage],
height: u32,
width: u32,
filter: &str,
bg: u8,
resize_by: &str,
center: bool,
) -> Result<Self> {
Ok(Self(Ops::letterbox(
xs, height, width, filter, bg, resize_by, center,
)?))
}
}

View File

@ -4,7 +4,7 @@ use ndarray::{s, Array, Axis, IxDyn};
use std::io::Write;
use tokenizers::Tokenizer;
use crate::{ops, Embedding, LogitsSampler, MinOptMax, Options, OrtEngine, TokenizerStream, Y};
use crate::{Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, X, Y};
#[derive(Debug)]
pub struct Blip {
@ -43,20 +43,23 @@ impl Blip {
}
pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result<Y> {
let xs_ = ops::resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"bilinear",
)?;
let xs_ = ops::normalize(xs_, 0., 255.);
let xs_ = ops::standardize(
xs_,
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.2613026, 0.2757771],
);
let ys: Vec<Array<f32, IxDyn>> = self.visual.run(&[xs_])?;
Ok(Y::default().with_embedding(Embedding::new(ys[0].to_owned())))
let xs_ = X::apply(&[
Ops::Resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"Bilinear",
),
Ops::Normalize(0., 255.),
Ops::Standardize(
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.2613026, 0.2757771],
3,
),
Ops::Nhwc2nchw,
])?;
let ys = self.visual.run(vec![xs_])?;
Ok(Y::default().with_embedding(Embedding::from(ys[0].to_owned())))
}
pub fn caption(
@ -100,13 +103,16 @@ impl Blip {
loop {
let input_ids_nd: Array<f32, IxDyn> = Array::from_vec(input_ids.to_owned()).into_dyn();
let input_ids_nd = input_ids_nd.insert_axis(Axis(0));
let input_ids_nd = X::from(input_ids_nd);
let input_ids_attn_mask: Array<f32, IxDyn> =
Array::ones(input_ids_nd.shape()).into_dyn();
let y = self.textual.run(&[
let input_ids_attn_mask = X::from(input_ids_attn_mask);
let y = self.textual.run(vec![
input_ids_nd,
input_ids_attn_mask,
image_embeds.data().to_owned(),
image_embeds_attn_mask.to_owned(),
X::from(image_embeds.data().to_owned()),
X::from(image_embeds_attn_mask.to_owned()),
])?; // N, length, vocab_size
let y = y[0].slice(s!(0, -1.., ..));
let logits = y.slice(s!(0, ..)).to_vec();

View File

@ -1,9 +1,10 @@
use crate::{ops, Embedding, MinOptMax, Options, OrtEngine, Y};
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Array2, IxDyn};
use ndarray::Array2;
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};
use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, X, Y};
#[derive(Debug)]
pub struct Clip {
pub textual: OrtEngine,
@ -53,20 +54,23 @@ impl Clip {
}
pub fn encode_images(&mut self, xs: &[DynamicImage]) -> Result<Y> {
let xs_ = ops::resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"bilinear",
)?;
let xs_ = ops::normalize(xs_, 0., 255.);
let xs_ = ops::standardize(
xs_,
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.2613026, 0.2757771],
);
let ys: Vec<Array<f32, IxDyn>> = self.visual.run(&[xs_])?;
Ok(Y::default().with_embedding(Embedding::new(ys[0].to_owned())))
let xs_ = X::apply(&[
Ops::Resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"Bilinear",
),
Ops::Normalize(0., 255.),
Ops::Standardize(
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.2613026, 0.2757771],
3,
),
Ops::Nhwc2nchw,
])?;
let ys = self.visual.run(vec![xs_])?;
Ok(Y::default().with_embedding(Embedding::from(ys[0].to_owned())))
}
pub fn encode_texts(&mut self, texts: &[String]) -> Result<Y> {
@ -79,8 +83,9 @@ impl Clip {
.flat_map(|i| i.get_ids().iter().map(|&b| b as f32))
.collect();
let xs = Array2::from_shape_vec((texts.len(), self.context_length), xs)?.into_dyn();
let ys = self.textual.run(&[xs])?;
Ok(Y::default().with_embedding(Embedding::new(ys[0].to_owned())))
let xs = X::from(xs);
let ys = self.textual.run(vec![xs])?;
Ok(Y::default().with_embedding(Embedding::from(ys[0].to_owned())))
}
pub fn batch_visual(&self) -> usize {

View File

@ -1,7 +1,8 @@
use crate::{ops, DynConf, Mbr, MinOptMax, Options, OrtEngine, Polygon, Y};
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Axis, IxDyn};
use ndarray::Axis;
use crate::{DynConf, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
#[derive(Debug)]
pub struct DB {
@ -45,54 +46,66 @@ impl DB {
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = ops::letterbox(
xs,
self.height.opt as u32,
self.width.opt as u32,
"bilinear",
Some(114),
)?;
let xs_ = ops::normalize(xs_, 0., 255.);
let xs_ = ops::standardize(xs_, &[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225]);
let ys = self.engine.run(&[xs_])?;
let xs_ = X::apply(&[
Ops::Letterbox(
xs,
self.height() as u32,
self.width() as u32,
"Bilinear",
114,
"auto",
false,
),
Ops::Normalize(0., 255.),
Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3),
Ops::Nhwc2nchw,
])?;
let ys = self.engine.run(vec![xs_])?;
self.postprocess(ys, xs)
}
pub fn postprocess(&self, xs: Vec<Array<f32, IxDyn>>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
let mut ys = Vec::new();
for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
let mut y_bbox = Vec::new();
let mut y_polygons: Vec<Polygon> = Vec::new();
let mut y_mbrs: Vec<Mbr> = Vec::new();
// reshape
let h = luma.dim()[1];
let w = luma.dim()[2];
let luma = luma.into_shape((h, w, 1))?.into_owned();
// build image from ndarray
let v = luma
.into_raw_vec()
.iter()
.map(|x| if x <= &self.binary_thresh { 0.0 } else { *x })
.collect::<Vec<_>>();
let mut mask_im =
ops::build_dyn_image_from_raw(v, self.height() as u32, self.width() as u32);
// input image
let image_width = xs0[idx].width() as f32;
let image_height = xs0[idx].height() as f32;
// rescale mask image
let (ratio, w_mask, h_mask) =
ops::scale_wh(image_width, image_height, w as f32, h as f32);
let mask_im = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
let mask_im = mask_im.resize_exact(
image_width as u32,
image_height as u32,
image::imageops::FilterType::Triangle,
);
let mask_im = mask_im.into_luma8();
// reshape
let h = luma.dim()[1];
let w = luma.dim()[2];
let (ratio, _, _) = Ops::scale_wh(image_width, image_height, w as f32, h as f32);
let v = luma
.into_owned()
.into_raw_vec()
.iter()
.map(|x| {
if x <= &self.binary_thresh {
0u8
} else {
(*x * 255.0) as u8
}
})
.collect::<Vec<_>>();
let luma = Ops::resize_luma8_vec(
&v,
self.width() as _,
self.height() as _,
image_width as _,
image_height as _,
true,
"Bilinear",
)?;
let mask_im: image::ImageBuffer<image::Luma<_>, Vec<_>> =
match image::ImageBuffer::from_raw(image_width as _, image_height as _, luma) {
None => continue,
Some(x) => x,
};
// contours
let contours: Vec<imageproc::contours::Contour<i32>> =
@ -105,14 +118,18 @@ impl DB {
{
continue;
}
let mask = Polygon::default().with_points_imageproc(&contour.points);
let delta = mask.area() * ratio.round() as f64 * self.unclip_ratio as f64
/ mask.perimeter();
// TODO: optimize
let mask = mask
.unclip(delta, image_width as f64, image_height as f64)
.resample(50)
// .simplify(6e-4)
.convex_hull();
if let Some(bbox) = mask.bbox() {
if bbox.height() < self.min_height || bbox.width() < self.min_width {
continue;
@ -131,6 +148,7 @@ impl DB {
continue;
}
}
ys.push(
Y::default()
.with_bboxes(&y_bbox)

View File

@ -1,7 +1,7 @@
use crate::{ops, Mask, MinOptMax, Options, OrtEngine, Y};
use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, X, Y};
use anyhow::Result;
use image::{DynamicImage, ImageBuffer};
use ndarray::{Array, Axis, IxDyn};
use image::DynamicImage;
use ndarray::Axis;
#[derive(Debug)]
pub struct DepthAnything {
@ -30,41 +30,50 @@ impl DepthAnything {
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = ops::resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"lanczos3",
)?;
let xs_ = ops::normalize(xs_, 0.0, 255.0);
let xs_ = ops::standardize(xs_, &[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225]);
let ys = self.engine.run(&[xs_])?;
let xs_ = X::apply(&[
Ops::Resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"Lanczos3",
),
Ops::Normalize(0., 255.),
Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3),
Ops::Nhwc2nchw,
])?;
let ys = self.engine.run(vec![xs_])?;
self.postprocess(ys, xs)
}
pub fn postprocess(&self, xs: Vec<Array<f32, IxDyn>>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
let luma = luma
.into_shape((self.height() as usize, self.width() as usize, 1))?
.into_owned();
let v = luma.into_raw_vec();
let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
let v = luma.into_owned().into_raw_vec();
let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap();
let v = v
.iter()
.map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
.collect::<Vec<_>>();
let luma: ImageBuffer<image::Luma<_>, Vec<u8>> =
ImageBuffer::from_raw(self.width() as u32, self.height() as u32, v)
.expect("Faild to create image from ndarray");
let luma = image::DynamicImage::from(luma);
let luma = luma.resize_exact(
xs0[idx].width(),
xs0[idx].height(),
image::imageops::FilterType::CatmullRom,
let luma = Ops::resize_luma8_vec(
&v,
self.width() as _,
self.height() as _,
w1 as _,
h1 as _,
false,
"Bilinear",
)?;
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
None => continue,
Some(x) => x,
};
ys.push(
Y::default().with_masks(&[Mask::default().with_mask(DynamicImage::from(luma))]),
);
ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
}
Ok(ys)
}

View File

@ -1,7 +1,6 @@
use crate::{ops, Embedding, MinOptMax, Options, OrtEngine, Y};
use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, X, Y};
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, IxDyn};
// use std::path::PathBuf;
// use usearch::ffi::{IndexOptions, MetricKind, ScalarKind};
@ -49,20 +48,23 @@ impl Dinov2 {
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Y> {
let xs_ = ops::resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"lanczos3",
)?;
let xs_ = ops::normalize(xs_, 0., 255.);
let xs_ = ops::standardize(
xs_,
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.2613026, 0.2757771],
);
let ys: Vec<Array<f32, IxDyn>> = self.engine.run(&[xs_])?;
Ok(Y::default().with_embedding(Embedding::new(ys[0].to_owned())))
let xs_ = X::apply(&[
Ops::Resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"Lanczos3",
),
Ops::Normalize(0., 255.),
Ops::Standardize(
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.2613026, 0.2757771],
3,
),
Ops::Nhwc2nchw,
])?;
let ys = self.engine.run(vec![xs_])?;
Ok(Y::default().with_embedding(Embedding::from(ys[0].to_owned())))
}
// pub fn build_index(&self, metric: Metric) -> Result<usearch::Index> {

View File

@ -1,8 +1,8 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Axis, IxDyn};
use ndarray::Axis;
use crate::{ops, Mask, MinOptMax, Options, OrtEngine, Y};
use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, X, Y};
#[derive(Debug)]
pub struct MODNet {
@ -31,37 +31,42 @@ impl MODNet {
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = ops::resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"lanczos3",
)?;
let xs_ = ops::normalize(xs_, 127.5, 255.);
let ys = self.engine.run(&[xs_])?;
let xs_ = X::apply(&[
Ops::Resize(
xs,
self.height.opt as u32,
self.width.opt as u32,
"Lanczos3",
),
Ops::Normalize(0., 255.),
Ops::Nhwc2nchw,
])?;
let ys = self.engine.run(vec![xs_])?;
self.postprocess(ys, xs)
}
pub fn postprocess(&self, xs: Vec<Array<f32, IxDyn>>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
let luma = luma
.into_shape((self.height() as usize, self.width() as usize, 1))?
.into_owned();
let v = luma
.into_raw_vec()
.iter()
.map(|x| (x * 255.0) as u8)
.collect::<Vec<_>>();
let luma: image::ImageBuffer<image::Luma<_>, Vec<u8>> =
image::ImageBuffer::from_raw(self.width() as u32, self.height() as u32, v)
.expect("Faild to create image from ndarray");
let luma = image::DynamicImage::from(luma);
let luma = luma.resize_exact(
xs0[idx].width(),
xs0[idx].height(),
image::imageops::FilterType::CatmullRom,
);
let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
let luma = luma.mapv(|x| (x * 255.0) as u8);
let luma = Ops::resize_luma8_vec(
&luma.into_raw_vec(),
self.width() as _,
self.height() as _,
w1 as _,
h1 as _,
false,
"Bilinear",
)?;
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
None => continue,
Some(x) => x,
};
let luma = DynamicImage::from(luma);
ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
}
Ok(ys)

View File

@ -1,9 +1,9 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Array, Axis, IxDyn};
use ndarray::{s, Axis};
use regex::Regex;
use crate::{ops, Bbox, DynConf, MinOptMax, Options, OrtEngine, Y};
use crate::{Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, X, Y};
#[derive(Debug)]
pub struct RTDETR {
@ -56,19 +56,24 @@ impl RTDETR {
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = ops::letterbox(
xs,
self.height() as u32,
self.width() as u32,
"catmullRom",
Some(114),
)?;
let xs_ = ops::normalize(xs_, 0.0, 255.0);
let ys = self.engine.run(&[xs_])?;
let xs_ = X::apply(&[
Ops::Letterbox(
xs,
self.height() as u32,
self.width() as u32,
"CatmullRom",
114,
"auto",
false,
),
Ops::Normalize(0., 255.),
Ops::Nhwc2nchw,
])?;
let ys = self.engine.run(vec![xs_])?;
self.postprocess(ys, xs)
}
pub fn postprocess(&self, xs: Vec<Array<f32, IxDyn>>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
const CXYWH_OFFSET: usize = 4; // cxcywh
let preds = &xs[0];

View File

@ -1,8 +1,8 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Axis, IxDyn};
use ndarray::Axis;
use crate::{ops, Bbox, DynConf, Keypoint, MinOptMax, Options, OrtEngine, Y};
use crate::{Bbox, DynConf, Keypoint, MinOptMax, Options, OrtEngine, X, Y};
#[derive(Debug)]
pub struct RTMO {
@ -39,18 +39,21 @@ impl RTMO {
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = ops::letterbox(
let xs_ = X::letterbox(
xs,
self.height() as u32,
self.width() as u32,
"catmullRom",
Some(114),
)?;
let ys = self.engine.run(&[xs_])?;
"CatmullRom",
114,
"auto",
false,
)?
.nhwc2nchw()?;
let ys = self.engine.run(vec![xs_])?;
self.postprocess(ys, xs)
}
pub fn postprocess(&self, xs: Vec<Array<f32, IxDyn>>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
let (preds_bboxes, preds_kpts) = if xs[0].ndim() == 3 {
(&xs[0], &xs[1])
@ -80,6 +83,7 @@ impl RTMO {
let x2 = xyxyc[2] / ratio;
let y2 = xyxyc[3] / ratio;
let confidence = xyxyc[4];
if confidence < self.confs[0] {
continue;
}

View File

@ -1,8 +1,8 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Axis, IxDyn};
use ndarray::Axis;
use crate::{ops, DynConf, MinOptMax, Options, OrtEngine, Y};
use crate::{DynConf, MinOptMax, Ops, Options, OrtEngine, X, Y};
#[derive(Debug)]
pub struct SVTR {
@ -43,22 +43,27 @@ impl SVTR {
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = ops::resize_with_fixed_height(
xs,
self.height.opt as u32,
self.width.opt as u32,
"bilinear",
Some(0),
)?;
let xs_ = ops::normalize(xs_, 0.0, 255.0);
let ys: Vec<Array<f32, IxDyn>> = self.engine.run(&[xs_])?;
let ys = ys[0].to_owned();
self.postprocess(&ys)
let xs_ = X::apply(&[
Ops::Letterbox(
xs,
self.height.opt as u32,
self.width.opt as u32,
"Bilinear",
0,
"auto",
false,
),
Ops::Normalize(0., 255.),
Ops::Nhwc2nchw,
])?;
let ys = self.engine.run(vec![xs_])?;
self.postprocess(ys)
}
pub fn postprocess(&self, output: &Array<f32, IxDyn>) -> Result<Vec<Y>> {
pub fn postprocess(&self, xs: Vec<X>) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
for batch in output.axis_iter(Axis(0)) {
for batch in xs[0].axis_iter(Axis(0)) {
let preds = batch
.axis_iter(Axis(0))
.filter_map(|x| {
@ -91,7 +96,6 @@ impl SVTR {
ys.push(Y::default().with_texts(&[text]))
}
Ok(ys)
}
}

View File

@ -1,10 +1,12 @@
use anyhow::Result;
use clap::ValueEnum;
use image::{DynamicImage, ImageBuffer};
use ndarray::{s, Array, Axis, IxDyn};
use image::DynamicImage;
use ndarray::{s, Array, Axis};
use regex::Regex;
use crate::{ops, Bbox, DynConf, Keypoint, Mbr, MinOptMax, Options, OrtEngine, Polygon, Prob, Y};
use crate::{
Bbox, DynConf, Keypoint, Mbr, MinOptMax, Ops, Options, OrtEngine, Polygon, Prob, Vision, X, Y,
};
const CXYWH_OFFSET: usize = 4;
const KPT_STEP: usize = 3;
@ -49,8 +51,10 @@ pub struct YOLO {
apply_probs_softmax: bool,
}
impl YOLO {
pub fn new(options: Options) -> Result<Self> {
impl Vision for YOLO {
type Input = DynamicImage;
fn new(options: Options) -> Result<Self> {
let mut engine = OrtEngine::new(&options)?;
let (batch, height, width) = (
engine.batch().to_owned(),
@ -161,32 +165,46 @@ impl YOLO {
})
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
// pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
fn preprocess(&self, xs: &[Self::Input]) -> Result<Vec<X>> {
let xs_ = match self.task {
YOLOTask::Classify => {
ops::resize(xs, self.height() as u32, self.width() as u32, "bilinear")?
X::resize(xs, self.height() as u32, self.width() as u32, "Bilinear")?
.normalize(0., 255.)?
.nhwc2nchw()?
}
_ => ops::letterbox(
xs,
self.height() as u32,
self.width() as u32,
"catmullRom",
Some(114),
)?,
_ => X::apply(&[
Ops::Letterbox(
xs,
self.height() as u32,
self.width() as u32,
"CatmullRom",
114,
"auto",
false,
),
Ops::Normalize(0., 255.),
Ops::Nhwc2nchw,
])?,
};
let xs_ = ops::normalize(xs_, 0., 255.);
let ys = self.engine.run(&[xs_])?;
self.postprocess(ys, xs)
Ok(vec![xs_])
// let ys = self.engine.run(vec![xs_])?;
// self.postprocess(ys, xs)
}
pub fn postprocess(&self, xs: Vec<Array<f32, IxDyn>>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
fn inference(&mut self, xs: Vec<X>) -> Result<Vec<X>> {
self.engine.run(xs)
}
// pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
fn postprocess(&self, xs: Vec<X>, xs0: &[Self::Input]) -> Result<Vec<Y>> {
let mut ys = Vec::new();
let protos = if xs.len() == 2 { Some(&xs[1]) } else { None };
for (idx, preds) in xs[0].axis_iter(Axis(0)).enumerate() {
let image_width = xs0[idx].width() as f32;
let image_height = xs0[idx].height() as f32;
// decode
match self.task {
YOLOTask::Classify => {
let y = if self.apply_probs_softmax {
@ -253,7 +271,7 @@ impl YOLO {
let ratio = (self.width() as f32 / image_width)
.min(self.height() as f32 / image_height);
// bboxes
// Detection
for (i, pred) in preds
.axis_iter(if self.anchors_first { Axis(0) } else { Axis(1) })
.enumerate()
@ -324,13 +342,13 @@ impl YOLO {
}
}
// nms
// NMS
let mut y = Y::default().with_bboxes(&y_bboxes);
if self.apply_nms {
y = y.apply_bboxes_nms(self.iou);
}
// keypoints
// Pose
if let YOLOTask::Pose = self.task {
if let Some(bboxes) = y.bboxes() {
let mut y_kpts: Vec<Vec<Keypoint>> = Vec::new();
@ -377,7 +395,7 @@ impl YOLO {
}
}
// masks
// Segment
if let YOLOTask::Segment = self.task {
if let Some(bboxes) = y.bboxes() {
let mut y_polygons: Vec<Polygon> = Vec::new();
@ -391,60 +409,55 @@ impl YOLO {
.slice(s![preds.shape()[0] - self.nm.., bbox.id_born()])
.to_vec()
};
let proto = protos.unwrap().slice(s![idx, .., .., ..]);
let (nm, mh, mw) = proto.dim();
// coefs * proto -> mask
let (nm, nh, nw) = proto.dim();
// coefs * proto => mask (311.427µs)
let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm)
let proto = proto.to_owned().into_shape((nm, nh * nw))?; // (nm, nh*nw)
let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, n)
let proto = proto.into_shape((nm, mh * mw))?; // (nm, mh * mw)
let mask = coefs.dot(&proto); // (mh, mw, n)
// build image from ndarray
let mask: ImageBuffer<image::Luma<_>, Vec<f32>> =
match ImageBuffer::from_raw(
nw as u32,
nh as u32,
mask.clone().into_raw_vec(),
// de-scale
let mask = Ops::resize_lumaf32_vec(
&mask.into_raw_vec(),
mw as _,
mh as _,
image_width as _,
image_height as _,
true,
"Bilinear",
)?;
let mut mask: image::ImageBuffer<image::Luma<_>, Vec<_>> =
match image::ImageBuffer::from_raw(
image_width as _,
image_height as _,
mask,
) {
Some(buf) => buf,
None => continue,
Some(x) => x,
};
let mask = image::DynamicImage::from(mask);
// rescale
let mask_original = ops::descale_mask(
mask,
nw as f32,
nh as f32,
image_width,
image_height,
);
let mut mask_original = mask_original.into_luma8();
let (xmin, ymin, xmax, ymax) =
(bbox.xmin(), bbox.ymin(), bbox.xmax(), bbox.ymax());
// crop mask
for y in 0..image_height as usize {
for x in 0..image_width as usize {
if x < bbox.xmin() as usize
|| x > bbox.xmax() as usize
|| y < bbox.ymin() as usize
|| y > bbox.ymax() as usize
// || mask_original.get_pixel(x as u32, y as u32).0 < [127]
// Using bbox to crop the mask (75.93µs)
for (y, row) in mask.enumerate_rows_mut() {
for (x, _, pixel) in row {
if x < xmin as _
|| x > xmax as _
|| y < ymin as _
|| y > ymax as _
{
mask_original.put_pixel(
x as u32,
y as u32,
image::Luma([0u8]),
);
*pixel = image::Luma([0u8]);
}
}
}
// get masks from image
// Find contours (1.413853ms)
let contours: Vec<imageproc::contours::Contour<i32>> =
imageproc::contours::find_contours_with_threshold(
&mask_original,
0,
);
imageproc::contours::find_contours_with_threshold(&mask, 0);
let polygon = match contours
.iter()
.map(|x| {
@ -458,6 +471,7 @@ impl YOLO {
None => continue,
Some(x) => x,
};
y_polygons.push(polygon);
}
y = y.with_polygons(&y_polygons);
@ -469,7 +483,9 @@ impl YOLO {
}
Ok(ys)
}
}
impl YOLO {
pub fn batch(&self) -> isize {
self.batch.opt
}

View File

@ -2,7 +2,7 @@ use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Array, Axis, IxDyn};
use crate::{ops, Bbox, DynConf, MinOptMax, Options, OrtEngine, Polygon, Y};
use crate::{Bbox, DynConf, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
#[derive(Debug)]
pub struct YOLOPv2 {
@ -37,19 +37,24 @@ impl YOLOPv2 {
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = ops::letterbox(
xs,
self.height() as u32,
self.width() as u32,
"bilinear",
Some(114),
)?;
let xs_ = ops::normalize(xs_, 0., 255.);
let ys = self.engine.run(&[xs_])?;
let xs_ = X::apply(&[
Ops::Letterbox(
xs,
self.height() as u32,
self.width() as u32,
"Bilinear",
114,
"auto",
false,
),
Ops::Normalize(0., 255.),
Ops::Nhwc2nchw,
])?;
let ys = self.engine.run(vec![xs_])?;
self.postprocess(ys, xs)
}
pub fn postprocess(&self, xs: Vec<Array<f32, IxDyn>>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
pub fn postprocess(&self, xs: Vec<X>, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
let (xs_da, xs_ll, xs_det) = (&xs[0], &xs[1], &xs[2]);
for (idx, ((x_det, x_ll), x_da)) in xs_det
@ -60,7 +65,7 @@ impl YOLOPv2 {
{
let image_width = xs0[idx].width() as f32;
let image_height = xs0[idx].height() as f32;
let (ratio, _, _) = ops::scale_wh(
let (ratio, _, _) = Ops::scale_wh(
image_width,
image_height,
self.width() as f32,
@ -97,32 +102,23 @@ impl YOLOPv2 {
.with_id(id as isize),
);
}
let mut y_polygons: Vec<Polygon> = Vec::new();
// Drivable area
let x_da_0 = x_da.slice(s![0, .., ..]).to_owned();
let x_da_1 = x_da.slice(s![1, .., ..]).to_owned();
let x_da = x_da_1 - x_da_0;
let x_da = x_da
.into_shape((self.height() as usize, self.width() as usize, 1))?
.into_owned();
let v = x_da
.into_raw_vec()
.iter()
.map(|x| if x < &0.0 { 0.0 } else { 1.0 })
.collect::<Vec<_>>();
let mask_da =
ops::build_dyn_image_from_raw(v, self.height() as u32, self.width() as u32);
let mask_da = ops::descale_mask(
mask_da,
self.width() as f32,
self.height() as f32,
let contours = match self.get_contours_from_mask(
x_da.into_dyn(),
0.0,
self.width() as _,
self.height() as _,
image_width,
image_height,
);
let mask_da = mask_da.into_luma8();
let mut y_polygons: Vec<Polygon> = Vec::new();
let contours: Vec<imageproc::contours::Contour<i32>> =
imageproc::contours::find_contours_with_threshold(&mask_da, 0);
) {
Err(_) => continue,
Ok(x) => x,
};
if let Some(polygon) = contours
.iter()
.map(|x| {
@ -137,26 +133,17 @@ impl YOLOPv2 {
};
// Lane line
let x_ll = x_ll
.into_shape((self.height() as usize, self.width() as usize, 1))?
.into_owned();
let v = x_ll
.into_raw_vec()
.iter()
.map(|x| if x < &0.5 { 0.0 } else { 1.0 })
.collect::<Vec<_>>();
let mask_ll =
ops::build_dyn_image_from_raw(v, self.height() as u32, self.width() as u32);
let mask_ll = ops::descale_mask(
mask_ll,
self.width() as f32,
self.height() as f32,
let contours = match self.get_contours_from_mask(
x_ll.to_owned(),
0.5,
self.width() as _,
self.height() as _,
image_width,
image_height,
);
let mask_ll = mask_ll.into_luma8();
let contours: Vec<imageproc::contours::Contour<i32>> =
imageproc::contours::find_contours_with_threshold(&mask_ll, 0);
) {
Err(_) => continue,
Ok(x) => x,
};
if let Some(polygon) = contours
.iter()
.map(|x| {
@ -192,4 +179,23 @@ impl YOLOPv2 {
pub fn height(&self) -> isize {
self.height.opt
}
fn get_contours_from_mask(
&self,
mask: Array<f32, IxDyn>,
thresh: f32,
w0: f32,
h0: f32,
w1: f32,
h1: f32,
) -> Result<Vec<imageproc::contours::Contour<i32>>> {
let mask = mask.mapv(|x| if x < thresh { 0u8 } else { 255u8 });
let mask = Ops::resize_luma8_vec(&mask.into_raw_vec(), w0, h0, w1, h1, false, "Bilinear")?;
let mask: image::ImageBuffer<image::Luma<_>, Vec<_>> =
image::ImageBuffer::from_raw(w1 as _, h1 as _, mask)
.ok_or(anyhow::anyhow!("Failed to build image"))?;
let contours: Vec<imageproc::contours::Contour<i32>> =
imageproc::contours::find_contours_with_threshold(&mask, 0);
Ok(contours)
}
}

View File

@ -1,6 +1,8 @@
use anyhow::Result;
use ndarray::{Array, Axis, Ix2, IxDyn};
use crate::X;
/// Embedding
#[derive(Clone, PartialEq, Default)]
pub struct Embedding(Array<f32, IxDyn>);
@ -11,6 +13,12 @@ impl std::fmt::Debug for Embedding {
}
}
impl From<X> for Embedding {
fn from(x: X) -> Self {
Self(x.0)
}
}
impl Embedding {
pub fn new(x: Array<f32, IxDyn>) -> Self {
Self(x)