Bump the version to 0.0.12

* Add sapiens-seg model
Author: Jamjamjon
Date: 2024-08-31 17:10:36 +08:00
Committed by: GitHub
Parent: f25f5cf2b5
Commit: f6755a8be4

23 changed files with 340 additions and 31 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "usls"
version = "0.0.11"
version = "0.0.12"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"

View File

@ -36,7 +36,7 @@
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10)
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet)
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569)
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World)
<details>
@ -70,6 +70,7 @@
| [Depth-Anything](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | |
</details>

BIN  assets/paul-george.jpg  Normal file (131 KiB)

View File

@ -1,7 +1,7 @@
use anyhow::Result;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use usls::{coco, models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion};
use usls::{models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17};
enum Stage {
Pre,
@ -60,7 +60,7 @@ pub fn benchmark_cuda(c: &mut Criterion, h: isize, w: isize) -> Result<()> {
.with_i02((320, h, 1280).into())
.with_i03((320, w, 1280).into())
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
.with_names2(&coco::KEYPOINTS_NAMES_17);
.with_names2(&COCO_KEYPOINTS_17);
let mut model = YOLO::new(options)?;
let xs = vec![DataLoader::try_read("./assets/bus.jpg")?];

View File

@ -11,7 +11,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut model = DepthAnything::new(options)?;
// load
let x = vec![DataLoader::try_read("./assets/2.jpg")?];
let x = [DataLoader::try_read("./assets/2.jpg")?];
// run
let y = model.run(&x)?;

View File

@ -1,4 +1,4 @@
use usls::{coco, models::RTMO, Annotator, DataLoader, Options};
use usls::{models::RTMO, Annotator, DataLoader, Options, COCO_SKELETONS_16};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build model
@ -19,7 +19,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// annotate
let annotator = Annotator::default()
.with_saveout("RTMO")
.with_skeletons(&coco::SKELETONS_16);
.with_skeletons(&COCO_SKELETONS_16);
annotator.annotate(&x, &y);
Ok(())

examples/sapiens/main.rs  Normal file (30 lines added)
View File

@ -0,0 +1,30 @@
use usls::{
models::{Sapiens, SapiensTask},
Annotator, DataLoader, Options, BODY_PARTS_28,
};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// build
let options = Options::default()
.with_model("sapiens-seg-0.3b-dyn.onnx")?
.with_sapiens_task(SapiensTask::Seg)
.with_names(&BODY_PARTS_28)
.with_profile(false)
.with_i00((1, 1, 8).into());
let mut model = Sapiens::new(options)?;
// load
let x = [DataLoader::try_read("./assets/paul-george.jpg")?];
// run
let y = model.run(&x)?;
// annotate
let annotator = Annotator::default()
.without_masks(true)
.with_polygons_name(false)
.with_saveout("Sapiens");
annotator.annotate(&x, &y);
Ok(())
}

View File

@ -1,7 +1,10 @@
use anyhow::Result;
use clap::Parser;
use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion};
use usls::{
models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17,
COCO_SKELETONS_16,
};
#[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)]
@ -174,8 +177,8 @@ fn main() -> Result<()> {
.with_i02((args.height_min, args.height, args.height_max).into())
.with_i03((args.width_min, args.width, args.width_max).into())
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
// .with_names(&coco::NAMES_80)
.with_names2(&coco::KEYPOINTS_NAMES_17)
// .with_names(&COCO_CLASS_NAMES_80)
.with_names2(&COCO_KEYPOINTS_17)
.with_find_contours(!args.no_contours) // find contours or not
.with_profile(args.profile);
let mut model = YOLO::new(options)?;
@ -187,7 +190,7 @@ fn main() -> Result<()> {
// build annotator
let annotator = Annotator::default()
.with_skeletons(&coco::SKELETONS_16)
.with_skeletons(&COCO_SKELETONS_16)
.with_bboxes_thickness(4)
.without_masks(true) // No masks plotting when doing segment task.
.with_saveout(saveout);

View File

@ -9,6 +9,7 @@ pub mod onnx;
pub mod ops;
mod options;
mod ort_engine;
mod task;
mod tokenizer_stream;
mod ts;
mod vision;
@ -25,6 +26,7 @@ pub use min_opt_max::MinOptMax;
pub use ops::Ops;
pub use options::Options;
pub use ort_engine::OrtEngine;
pub use task::Task;
pub use tokenizer_stream::TokenizerStream;
pub use ts::Ts;
pub use vision::Vision;

View File

@ -7,7 +7,7 @@ use fast_image_resize::{
FilterType, ResizeAlg, ResizeOptions, Resizer,
};
use image::{DynamicImage, GenericImageView};
use ndarray::{s, Array, Axis, IntoDimension, IxDyn};
use ndarray::{s, Array, Array3, Axis, IntoDimension, IxDyn};
use rayon::prelude::*;
pub enum Ops<'a> {
@ -159,7 +159,41 @@ impl Ops<'_> {
mask.resize_exact(w1 as u32, h1 as u32, image::imageops::FilterType::Triangle)
}
pub fn resize_lumaf32_vec(
// pub fn argmax(xs: Array<f32, IxDyn>, d: usize, keep_dims: bool) -> Result<Array<f32, IxDyn>> {
// let mask = Array::zeros(xs.raw_dim());
// todo!();
// }
pub fn interpolate_3d(
xs: Array<f32, IxDyn>,
tw: f32,
th: f32,
filter: &str,
) -> Result<Array<f32, IxDyn>> {
let d_max = xs.ndim();
if d_max != 3 {
anyhow::bail!("`interpolate_3d`: The input's ndim: {} is not 3.", d_max);
}
let (n, h, w) = (xs.shape()[0], xs.shape()[1], xs.shape()[2]);
let mut ys = Array3::zeros((n, th as usize, tw as usize));
for (i, luma) in xs.axis_iter(Axis(0)).enumerate() {
let v = Ops::resize_lumaf32_f32(
&luma.to_owned().into_raw_vec_and_offset().0,
w as _,
h as _,
tw as _,
th as _,
false,
filter,
)?;
let y_ = Array::from_shape_vec((th as usize, tw as usize), v)?;
ys.slice_mut(s![i, .., ..]).assign(&y_);
}
Ok(ys.into_dyn())
}
pub fn resize_lumaf32_u8(
v: &[f32],
w0: f32,
h0: f32,
@ -168,6 +202,20 @@ impl Ops<'_> {
crop_src: bool,
filter: &str,
) -> Result<Vec<u8>> {
let mask_f32 = Self::resize_lumaf32_f32(v, w0, h0, w1, h1, crop_src, filter)?;
let v: Vec<u8> = mask_f32.par_iter().map(|&x| (x * 255.0) as u8).collect();
Ok(v)
}
pub fn resize_lumaf32_f32(
v: &[f32],
w0: f32,
h0: f32,
w1: f32,
h1: f32,
crop_src: bool,
filter: &str,
) -> Result<Vec<f32>> {
let src = Image::from_vec_u8(
w0 as _,
h0 as _,
@ -189,12 +237,10 @@ impl Ops<'_> {
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
// f32 -> u8
let v: Vec<u8> = mask_f32.par_iter().map(|&x| (x * 255.0) as u8).collect();
Ok(v)
Ok(mask_f32)
}
pub fn resize_luma8_vec(
pub fn resize_luma8_u8(
v: &[u8],
w0: f32,
h0: f32,

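The new Ops::interpolate_3d resizes every channel of a (C, H, W) float tensor independently; the Sapiens post-processing further down uses it to bring the low-resolution per-part score maps back to the input image size. A minimal usage sketch based only on the signature shown in this diff: the zero-filled tensor, its shape, and the 640x480 target are made-up illustration values, and it assumes Ops is reachable as usls::Ops through the re-export in src/core/mod.rs.

use ndarray::Array3;
use usls::Ops;

fn main() -> anyhow::Result<()> {
    // A (C, H, W) stack of per-class score maps, e.g. 28 body-part channels at 64x64.
    let logits = Array3::<f32>::zeros((28, 64, 64)).into_dyn();

    // Resize each channel to the original 640x480 image with the "Bilinear" filter,
    // the same filter string the Sapiens post-processing passes.
    let resized = Ops::interpolate_3d(logits, 640.0, 480.0, "Bilinear")?;
    assert_eq!(resized.shape(), &[28, 480, 640]);
    Ok(())
}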
View File

@ -4,7 +4,7 @@ use anyhow::Result;
use crate::{
auto_load,
models::{SamKind, YOLOPreds, YOLOTask, YOLOVersion},
models::{SamKind, SapiensTask, YOLOPreds, YOLOTask, YOLOVersion},
Device, MinOptMax,
};
@ -92,6 +92,7 @@ pub struct Options {
pub find_contours: bool,
pub sam_kind: Option<SamKind>,
pub use_low_res_mask: Option<bool>,
pub sapiens_task: Option<SapiensTask>,
}
impl Default for Options {
@ -175,6 +176,7 @@ impl Default for Options {
find_contours: false,
sam_kind: None,
use_low_res_mask: None,
sapiens_task: None,
}
}
}
@ -220,6 +222,11 @@ impl Options {
self
}
pub fn with_sapiens_task(mut self, x: SapiensTask) -> Self {
self.sapiens_task = Some(x);
self
}
pub fn with_yolo_version(mut self, x: YOLOVersion) -> Self {
self.yolo_version = Some(x);
self

src/core/task.rs  Normal file (27 lines added)
View File

@ -0,0 +1,27 @@
#[derive(Debug, Clone)]
pub enum Task {
// vision
ImageClassification,
ObjectDetection,
KeypointsDetection,
RegionProposal,
PoseEstimation,
SemanticSegmentation,
InstanceSegmentation,
DepthEstimation,
SurfaceNormalPrediction,
Image2ImageGeneration,
Inpainting,
SuperResolution,
Denoising,
// vl
Tagging,
Captioning,
DetailedCaptioning,
MoreDetailedCaptioning,
PhraseGrounding,
Vqa,
Ocr,
Text2ImageGeneration,
}
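Task is only declared and re-exported in this commit; nothing consumes it yet. A small hypothetical sketch of how downstream code might branch on it: the helper name is invented, and the usls::Task path assumes the pub use task::Task re-export added in src/core/mod.rs is visible at the crate root, as the other core items are.

use usls::Task;

// Hypothetical helper: dense-prediction tasks need per-pixel post-processing.
fn is_dense_prediction(task: &Task) -> bool {
    matches!(
        task,
        Task::SemanticSegmentation
            | Task::InstanceSegmentation
            | Task::DepthEstimation
            | Task::SurfaceNormalPrediction
    )
}

fn main() {
    let task = Task::SemanticSegmentation;
    println!("{:?} is dense prediction: {}", task, is_dense_prediction(&task));
}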

View File

@ -24,6 +24,7 @@
//! - [YOLOPv2](https://arxiv.org/abs/2208.11434): Panoptic Driving Perception
//! - [Depth-Anything (v1, v2)](https://github.com/LiheYoung/Depth-Anything): Monocular Depth Estimation
//! - [MODNet](https://github.com/ZHKKKe/MODNet): Image Matting
//! - [Sapiens](https://arxiv.org/abs/2408.12569): Human-centric Vision Tasks
//!
//! # Examples
//!
@ -35,7 +36,7 @@
//! Using provided [`models`] with [`Options`]
//!
//! ```rust, no_run
//! use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision};
//! use usls::{models::YOLO, Annotator, DataLoader, Options, Vision, COCO_CLASS_NAMES_80};
//!
//! let options = Options::default()
//! .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
@ -74,7 +75,7 @@
//!
//! ```rust, no_run
//! let options = Options::default()
//! .with_names(&coco::NAMES_80);
//! .with_names(&COCO_CLASS_NAMES_80);
//! ```
//!
//! More options can be found in the [`Options`] documentation.

View File

@ -93,7 +93,7 @@ impl DB {
})
.collect::<Vec<_>>();
let luma = Ops::resize_luma8_vec(
let luma = Ops::resize_luma8_u8(
&v,
self.width() as _,
self.height() as _,

View File

@ -57,7 +57,7 @@ impl DepthAnything {
.map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
.collect::<Vec<_>>();
let luma = Ops::resize_luma8_vec(
let luma = Ops::resize_luma8_u8(
&v,
self.width() as _,
self.height() as _,

View File

@ -9,6 +9,7 @@ mod grounding_dino;
mod modnet;
mod rtmo;
mod sam;
mod sapiens;
mod svtr;
mod yolo;
mod yolo_;
@ -23,6 +24,7 @@ pub use grounding_dino::GroundingDINO;
pub use modnet::MODNet;
pub use rtmo::RTMO;
pub use sam::{SamKind, SamPrompt, SAM};
pub use sapiens::{Sapiens, SapiensTask};
pub use svtr::SVTR;
pub use yolo::YOLO;
pub use yolo_::*;

View File

@ -51,7 +51,7 @@ impl MODNet {
for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
let luma = luma.mapv(|x| (x * 255.0) as u8);
let luma = Ops::resize_luma8_vec(
let luma = Ops::resize_luma8_u8(
&luma.into_raw_vec_and_offset().0,
self.width() as _,
self.height() as _,

View File

@ -264,7 +264,7 @@ impl SAM {
let (h, w) = mask.dim();
let luma = if self.use_low_res_mask {
Ops::resize_lumaf32_vec(
Ops::resize_lumaf32_u8(
&mask.into_owned().into_raw_vec_and_offset().0,
w as _,
h as _,

src/models/sapiens.rs  Normal file (158 lines added)
View File

@ -0,0 +1,158 @@
use anyhow::Result;
use image::DynamicImage;
use ndarray::{s, Array2, Axis};
use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y};
#[derive(Debug, Clone, clap::ValueEnum)]
pub enum SapiensTask {
Seg,
Depth,
Normal,
Pose,
}
#[derive(Debug)]
pub struct Sapiens {
engine_seg: OrtEngine,
height: MinOptMax,
width: MinOptMax,
batch: MinOptMax,
task: SapiensTask,
names_body: Option<Vec<String>>,
}
impl Sapiens {
pub fn new(options_seg: Options) -> Result<Self> {
let mut engine_seg = OrtEngine::new(&options_seg)?;
let (batch, height, width) = (
engine_seg.batch().to_owned(),
engine_seg.height().to_owned(),
engine_seg.width().to_owned(),
);
let task = options_seg
.sapiens_task
.expect("Error: No sapiens task specified.");
let names_body = options_seg.names;
engine_seg.dry_run()?;
Ok(Self {
engine_seg,
height,
width,
batch,
task,
names_body,
})
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = X::apply(&[
Ops::Resize(xs, self.height() as u32, self.width() as u32, "Bilinear"),
Ops::Standardize(&[123.5, 116.5, 103.5], &[58.5, 57.0, 57.5], 3),
Ops::Nhwc2nchw,
])?;
match self.task {
SapiensTask::Seg => {
let ys = self.engine_seg.run(Xs::from(xs_))?;
self.postprocess_seg(ys, xs)
}
_ => todo!(),
}
}
pub fn postprocess_seg(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
for (idx, b) in xs[0].axis_iter(Axis(0)).enumerate() {
let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
// rescale
let masks = Ops::interpolate_3d(b.to_owned(), w1 as _, h1 as _, "Bilinear")?;
// generate mask
let mut mask = Array2::<usize>::zeros((h1 as _, w1 as _));
let mut ids = Vec::new();
for hh in 0..h1 {
for ww in 0..w1 {
let pt_slice = masks.slice(s![.., hh as usize, ww as usize]);
let (i, c) = match pt_slice
.into_iter()
.enumerate()
.max_by(|a, b| a.1.total_cmp(b.1))
{
Some((i, c)) => (i, c),
None => continue,
};
if *c <= 0. || i == 0 {
continue;
}
mask[[hh as _, ww as _]] = i;
if !ids.contains(&i) {
ids.push(i);
}
}
}
// generate masks and polygons
let mut y_masks: Vec<Mask> = Vec::new();
let mut y_polygons: Vec<Polygon> = Vec::new();
for i in ids.iter() {
let luma = mask.mapv(|x| if x == *i { 255 } else { 0 });
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
match image::ImageBuffer::from_raw(
w1 as _,
h1 as _,
luma.into_raw_vec_and_offset().0,
) {
None => continue,
Some(x) => x,
};
// contours
let contours: Vec<imageproc::contours::Contour<i32>> =
imageproc::contours::find_contours_with_threshold(&luma, 0);
let polygon = match contours
.into_iter()
.map(|x| {
let mut polygon = Polygon::default()
.with_id(*i as _)
.with_points_imageproc(&x.points);
if let Some(names_body) = &self.names_body {
polygon = polygon.with_name(&names_body[*i]);
}
polygon
})
.max_by(|x, y| x.area().total_cmp(&y.area()))
{
Some(p) => p,
None => continue,
};
y_polygons.push(polygon);
let mut mask = Mask::default().with_mask(luma).with_id(*i as _);
if let Some(names_body) = &self.names_body {
mask = mask.with_name(&names_body[*i]);
}
y_masks.push(mask);
}
ys.push(Y::default().with_masks(&y_masks).with_polygons(&y_polygons));
}
Ok(ys)
}
pub fn batch(&self) -> isize {
self.batch.opt
}
pub fn width(&self) -> isize {
self.width.opt
}
pub fn height(&self) -> isize {
self.height.opt
}
}
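postprocess_seg assigns each pixel the body-part channel with the highest score, skips the background channel 0 and non-positive scores, and then cuts one mask and one largest polygon per part id. A self-contained sketch of just that per-pixel argmax step, written against plain ndarray with toy data rather than the crate's types:

use ndarray::{s, Array2, Array3};

// Per-pixel argmax over the channel axis of a (C, H, W) score tensor.
// Leaves 0 (background) wherever no channel wins with a positive score.
fn argmax_channels(scores: &Array3<f32>) -> Array2<usize> {
    let (_, h, w) = scores.dim();
    let mut ids = Array2::<usize>::zeros((h, w));
    for y in 0..h {
        for x in 0..w {
            let pixel = scores.slice(s![.., y, x]);
            if let Some((i, &c)) = pixel.iter().enumerate().max_by(|a, b| a.1.total_cmp(b.1)) {
                if c > 0.0 && i != 0 {
                    ids[[y, x]] = i;
                }
            }
        }
    }
    ids
}

fn main() {
    // Channel 2 always holds the highest (and only positive) score.
    let scores = Array3::from_shape_fn((3, 2, 2), |(c, _, _)| c as f32 - 1.0);
    assert_eq!(argmax_channels(&scores)[[0, 0]], 2);
}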

View File

@ -421,7 +421,7 @@ impl Vision for YOLO {
let mask = coefs.dot(&proto); // (mh, mw, n)
// Mask rescale
let mask = Ops::resize_lumaf32_vec(
let mask = Ops::resize_lumaf32_u8(
&mask.into_raw_vec_and_offset().0,
mw as _,
mh as _,

View File

@ -191,7 +191,7 @@ impl YOLOPv2 {
h1: f32,
) -> Result<Vec<imageproc::contours::Contour<i32>>> {
let mask = mask.mapv(|x| if x < thresh { 0u8 } else { 255u8 });
let mask = Ops::resize_luma8_vec(
let mask = Ops::resize_luma8_u8(
&mask.into_raw_vec_and_offset().0,
w0,
h0,

View File

@ -4,10 +4,11 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
pub mod coco;
pub mod colormap256;
pub mod names;
pub use colormap256::*;
pub use names::*;
pub(crate) const GITHUB_ASSETS: &str =
"https://github.com/jamjamjon/assets/releases/download/v0.0.1";

View File

@ -1,6 +1,6 @@
//! Some constants releated with COCO dataset: [`SKELETONS_16`], [`KEYPOINTS_NAMES_17`], [`NAMES_80`]
//! Some constants related to the COCO dataset: [`COCO_SKELETONS_16`], [`COCO_KEYPOINTS_17`], [`COCO_CLASS_NAMES_80`]
pub const SKELETONS_16: [(usize, usize); 16] = [
pub const COCO_SKELETONS_16: [(usize, usize); 16] = [
(0, 1),
(0, 2),
(1, 3),
@ -19,7 +19,7 @@ pub const SKELETONS_16: [(usize, usize); 16] = [
(14, 16),
];
pub const KEYPOINTS_NAMES_17: [&str; 17] = [
pub const COCO_KEYPOINTS_17: [&str; 17] = [
"nose",
"left_eye",
"right_eye",
@ -39,7 +39,7 @@ pub const KEYPOINTS_NAMES_17: [&str; 17] = [
"right_ankle",
];
pub const NAMES_80: [&str; 80] = [
pub const COCO_CLASS_NAMES_80: [&str; 80] = [
"person",
"bicycle",
"car",
@ -121,3 +121,34 @@ pub const NAMES_80: [&str; 80] = [
"hair drier",
"toothbrush",
];
pub const BODY_PARTS_28: [&str; 28] = [
"Background",
"Apparel",
"Face Neck",
"Hair",
"Left Foot",
"Left Hand",
"Left Lower Arm",
"Left Lower Leg",
"Left Shoe",
"Left Sock",
"Left Upper Arm",
"Left Upper Leg",
"Lower Clothing",
"Right Foot",
"Right Hand",
"Right Lower Arm",
"Right Lower Leg",
"Right Shoe",
"Right Sock",
"Right Upper Arm",
"Right Upper Leg",
"Torso",
"Upper Clothing",
"Lower Lip",
"Upper Lip",
"Lower Teeth",
"Upper Teeth",
"Tongue",
];
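All of these name constants are plain &str arrays indexed by class or keypoint id, so a model's raw id maps straight to a human-readable label. A minimal check, assuming the re-exports in the utils module diff above surface them at the crate root, which is how the updated examples import them:

use usls::{BODY_PARTS_28, COCO_CLASS_NAMES_80, COCO_KEYPOINTS_17, COCO_SKELETONS_16};

fn main() {
    // Class and keypoint ids index straight into the name tables.
    assert_eq!(COCO_CLASS_NAMES_80[0], "person");
    assert_eq!(COCO_KEYPOINTS_17[0], "nose");
    assert_eq!(BODY_PARTS_28[21], "Torso");
    // Skeleton entries are (from, to) keypoint-index pairs used for drawing limbs.
    assert_eq!(COCO_SKELETONS_16[0], (0, 1));
}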