This commit is contained in:
jamjamjon
2025-05-19 23:25:48 +08:00
parent 5394c9ba27
commit 26de63d239
36 changed files with 4946 additions and 329 deletions

View File

@ -12,7 +12,7 @@ struct Args {
device: String, device: String,
/// model name /// model name
#[argh(option, default = "String::from(\"beit\")")] #[argh(option, default = "String::from(\"mobileone\")")]
model: String, model: String,
/// source image /// source image

View File

@ -27,6 +27,7 @@ fn main() -> Result<()> {
.with_model_dtype(args.dtype.as_str().try_into()?) .with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?)
.commit()?; .commit()?;
let mut model = DepthPro::new(config)?; let mut model = DepthPro::new(config)?;
// load // load

View File

@ -45,7 +45,7 @@ fn main() -> Result<()> {
annotator.annotate(x, y)?.save(format!( annotator.annotate(x, y)?.save(format!(
"{}.jpg", "{}.jpg",
usls::Dir::Current usls::Dir::Current
.base_dir_with_subs(&["runs", "FastSAM"])? .base_dir_with_subs(&["runs", model.spec()])?
.join(usls::timestamp(None)) .join(usls::timestamp(None))
.display(), .display(),
))?; ))?;

View File

@ -51,7 +51,7 @@ fn main() -> Result<()> {
// owlv2_base() // owlv2_base()
.with_model_dtype(args.dtype.as_str().try_into()?) .with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?)
.with_class_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>()) .with_text_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>())
.commit()?; .commit()?;
let mut model = OWLv2::new(config)?; let mut model = OWLv2::new(config)?;

View File

@ -1,6 +1,5 @@
## Quick Start ## Quick Start
```Shell ```Shell
cargo run -r -F cuda --example sam2 -- --device cuda --scale t
cargo run -r -F cuda --example sam -- --device cuda --scale t
``` ```

View File

@ -1,7 +1,7 @@
## Quick Start ## Quick Start
```shell ```shell
cargo run -r -F cuda --example sapiens -- --device cuda cargo run -r -F cuda --example sapiens -- --device cuda
``` ```
## Results ## Results

View File

@ -36,7 +36,6 @@ fn main() -> Result<()> {
} }
.with_device_all(args.device.as_str().try_into()?) .with_device_all(args.device.as_str().try_into()?)
.commit()?; .commit()?;
let mut model = SmolVLM::new(config)?; let mut model = SmolVLM::new(config)?;
// load images // load images

View File

@ -22,7 +22,7 @@ fn main() -> Result<()> {
let args: Args = argh::from_env(); let args: Args = argh::from_env();
// build model // build model
let config = ModelConfig::yolo_v8_rtdetr_l() let config = ModelConfig::ultralytics_rtdetr_l()
.with_model_dtype(args.dtype.as_str().try_into()?) .with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?)
.commit()?; .commit()?;
@ -41,7 +41,7 @@ fn main() -> Result<()> {
annotator.annotate(x, y)?.save(format!( annotator.annotate(x, y)?.save(format!(
"{}.jpg", "{}.jpg",
usls::Dir::Current usls::Dir::Current
.base_dir_with_subs(&["runs", "YOLOv8-RT-DETR"])? .base_dir_with_subs(&["runs", "ultralytics-RTDETR"])?
.join(usls::timestamp(None)) .join(usls::timestamp(None))
.display(), .display(),
))?; ))?;

View File

@ -28,7 +28,6 @@ fn main() -> Result<()> {
.with_scale(Scale::N) .with_scale(Scale::N)
.with_version(8.into()) .with_version(8.into())
.with_model_device(args.device.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?)
.auto_yolo_model_file()
.commit()?; .commit()?;
let mut yolo = YOLO::new(options_yolo)?; let mut yolo = YOLO::new(options_yolo)?;

View File

@ -27,27 +27,29 @@ cargo run -r --example yolo -- --task detect --ver v8 --num-classes 6 --model xx
# Classify # Classify
cargo run -r --example yolo -- --task classify --ver 5 --scale s --image-width 224 --image-height 224 --num-classes 1000 --use-imagenet-1k-classes # YOLOv5 cargo run -r --example yolo -- --task classify --ver 5 --scale s --image-width 224 --image-height 224 --num-classes 1000 --use-imagenet-1k-classes # YOLOv5
cargo run -r --example yolo -- --task classify --ver 8 --scale n --image-width 224 --image-height 224 # YOLOv8 cargo run -r --example yolo -- --task classify --ver 8 --scale n --image-width 224 --image-height 224 --use-imagenet-1k-classes # YOLOv8
cargo run -r --example yolo -- --task classify --ver 11 --scale n --image-width 224 --image-height 224 # YOLOv11 cargo run -r --example yolo -- --task classify --ver 11 --scale n --image-width 224 --image-height 224 # YOLO11
# Detect # Detect
cargo run -r --example yolo -- --task detect --ver 5 --scale n --use-coco-80-classes # YOLOv5 cargo run -r --example yolo -- --task detect --ver 5 --scale n --use-coco-80-classes --dtype fp16 # YOLOv5
cargo run -r --example yolo -- --task detect --ver 6 --scale n --use-coco-80-classes # YOLOv6 cargo run -r --example yolo -- --task detect --ver 6 --scale n --use-coco-80-classes --dtype fp16 # YOLOv6
cargo run -r --example yolo -- --task detect --ver 7 --scale t --use-coco-80-classes # YOLOv7 cargo run -r --example yolo -- --task detect --ver 7 --scale t --use-coco-80-classes --dtype fp16 # YOLOv7
cargo run -r --example yolo -- --task detect --ver 8 --scale n --use-coco-80-classes # YOLOv8 cargo run -r --example yolo -- --task detect --ver 8 --scale n --use-coco-80-classes --dtype fp16 # YOLOv8
cargo run -r --example yolo -- --task detect --ver 9 --scale t --use-coco-80-classes # YOLOv9 cargo run -r --example yolo -- --task detect --ver 9 --scale t --use-coco-80-classes --dtype fp16 # YOLOv9
cargo run -r --example yolo -- --task detect --ver 10 --scale n --use-coco-80-classes # YOLOv10 cargo run -r --example yolo -- --task detect --ver 10 --scale n --use-coco-80-classes --dtype fp16 # YOLOv10
cargo run -r --example yolo -- --task detect --ver 11 --scale n --use-coco-80-classes # YOLOv11 cargo run -r --example yolo -- --task detect --ver 11 --scale n --use-coco-80-classes --dtype fp16 # YOLO11
cargo run -r --example yolo -- --task detect --ver 8 --model v8-s-world-v2-shoes.onnx # YOLOv8-world cargo run -r --example yolo -- --task detect --ver 12 --scale n --use-coco-80-classes --dtype fp16 # YOLOv12
cargo run -r --example yolo -- --task detect --ver 8 --model v8-s-world-v2-shoes.onnx # YOLOv8-world
# Pose # Pose
cargo run -r --example yolo -- --task pose --ver 8 --scale n # YOLOv8-Pose cargo run -r --example yolo -- --task pose --ver 8 --scale n # YOLOv8-Pose
cargo run -r --example yolo -- --task pose --ver 11 --scale n # YOLOv11-Pose cargo run -r --example yolo -- --task pose --ver 11 --scale n # YOLO11-Pose
# Segment # Segment
cargo run -r --example yolo -- --task segment --ver 5 --scale n # YOLOv5-Segment cargo run -r --example yolo -- --task segment --ver 5 --scale n --use-coco-80-classes --dtype fp16 # YOLOv5-Segment
cargo run -r --example yolo -- --task segment --ver 8 --scale n # YOLOv8-Segment cargo run -r --example yolo -- --task segment --ver 8 --scale n --use-coco-80-classes --dtype fp16 # YOLOv8-Segment
cargo run -r --example yolo -- --task segment --ver 11 --scale n # YOLOv8-Segment cargo run -r --example yolo -- --task segment --ver 9 --scale c --use-coco-80-classes --dtype fp16 # YOLOv9-Segment
cargo run -r --example yolo -- --task segment --ver 11 --scale n --use-coco-80-classes --dtype fp16 # YOLO11-Segment
# Obb # Obb
cargo run -r --example yolo -- --ver 8 --task obb --scale n --image-width 1024 --image-height 1024 --source images/dota.png # YOLOv8-Obb cargo run -r --example yolo -- --ver 8 --task obb --scale n --image-width 1024 --image-height 1024 --source images/dota.png # YOLOv8-Obb

View File

@ -5,21 +5,21 @@ use usls::{
}; };
#[derive(argh::FromArgs, Debug)] #[derive(argh::FromArgs, Debug)]
/// Example /// YOLO Example
struct Args { struct Args {
/// model file /// model file(.onnx)
#[argh(option)] #[argh(option)]
model: Option<String>, model: Option<String>,
/// source /// source: image, image folder, video stream
#[argh(option, default = "String::from(\"./assets/bus.jpg\")")] #[argh(option, default = "String::from(\"./assets/bus.jpg\")")]
source: String, source: String,
/// dtype /// model dtype
#[argh(option, default = "String::from(\"auto\")")] #[argh(option, default = "String::from(\"auto\")")]
dtype: String, dtype: String,
/// task /// task: det, seg, pose, classify, obb
#[argh(option, default = "String::from(\"det\")")] #[argh(option, default = "String::from(\"det\")")]
task: String, task: String,
@ -27,101 +27,101 @@ struct Args {
#[argh(option, default = "8.0")] #[argh(option, default = "8.0")]
ver: f32, ver: f32,
/// device /// device: cuda, cpu, mps
#[argh(option, default = "String::from(\"cpu:0\")")] #[argh(option, default = "String::from(\"cpu:0\")")]
device: String, device: String,
/// scale /// scale: n, s, m, l, x
#[argh(option, default = "String::from(\"n\")")] #[argh(option, default = "String::from(\"n\")")]
scale: String, scale: String,
/// trt_fp16 /// enable TensorRT FP16
#[argh(option, default = "true")] #[argh(option, default = "true")]
trt_fp16: bool, trt_fp16: bool,
/// batch_size /// batch size
#[argh(option, default = "1")] #[argh(option, default = "1")]
batch_size: usize, batch_size: usize,
/// min_batch_size /// min batch size: For TensorRT
#[argh(option, default = "1")] #[argh(option, default = "1")]
min_batch_size: usize, min_batch_size: usize,
/// max_batch_size /// max batch size: For TensorRT
#[argh(option, default = "4")] #[argh(option, default = "4")]
max_batch_size: usize, max_batch_size: usize,
/// min_image_width /// min image width: For TensorRT
#[argh(option, default = "224")] #[argh(option, default = "224")]
min_image_width: isize, min_image_width: isize,
/// image_width /// image width: For TensorRT
#[argh(option, default = "640")] #[argh(option, default = "640")]
image_width: isize, image_width: isize,
/// max_image_width /// max image width: For TensorRT
#[argh(option, default = "1280")] #[argh(option, default = "1280")]
max_image_width: isize, max_image_width: isize,
/// min_image_height /// min image height: For TensorRT
#[argh(option, default = "224")] #[argh(option, default = "224")]
min_image_height: isize, min_image_height: isize,
/// image_height /// image height: For TensorRT
#[argh(option, default = "640")] #[argh(option, default = "640")]
image_height: isize, image_height: isize,
/// max_image_height /// max image height: For TensorRT
#[argh(option, default = "1280")] #[argh(option, default = "1280")]
max_image_height: isize, max_image_height: isize,
/// num_classes /// num classes
#[argh(option)] #[argh(option)]
num_classes: Option<usize>, num_classes: Option<usize>,
/// num_keypoints /// num keypoints
#[argh(option)] #[argh(option)]
num_keypoints: Option<usize>, num_keypoints: Option<usize>,
/// use_coco_80_classes /// class names
#[argh(switch)]
use_coco_80_classes: bool,
/// use_coco_17_keypoints_classes
#[argh(switch)]
use_coco_17_keypoints_classes: bool,
/// use_imagenet_1k_classes
#[argh(switch)]
use_imagenet_1k_classes: bool,
/// confs
#[argh(option)]
confs: Vec<f32>,
/// keypoint_confs
#[argh(option)]
keypoint_confs: Vec<f32>,
/// exclude_classes
#[argh(option)]
exclude_classes: Vec<usize>,
/// retain_classes
#[argh(option)]
retain_classes: Vec<usize>,
/// class_names
#[argh(option)] #[argh(option)]
class_names: Vec<String>, class_names: Vec<String>,
/// keypoint_names /// keypoint names
#[argh(option)] #[argh(option)]
keypoint_names: Vec<String>, keypoint_names: Vec<String>,
/// topk /// top-k
#[argh(option, default = "5")] #[argh(option, default = "5")]
topk: usize, topk: usize,
/// use COCO 80 classes
#[argh(switch)]
use_coco_80_classes: bool,
/// use COCO 17 keypoints classes
#[argh(switch)]
use_coco_17_keypoints_classes: bool,
/// use ImageNet 1K classes
#[argh(switch)]
use_imagenet_1k_classes: bool,
/// confidences
#[argh(option)]
confs: Vec<f32>,
/// keypoint confidences
#[argh(option)]
keypoint_confs: Vec<f32>,
/// exclude classes
#[argh(option)]
exclude_classes: Vec<usize>,
/// retain classes
#[argh(option)]
retain_classes: Vec<usize>,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
@ -129,9 +129,7 @@ fn main() -> Result<()> {
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init(); .init();
let args: Args = argh::from_env(); let args: Args = argh::from_env();
let mut config = ModelConfig::yolo() let mut config = ModelConfig::yolo()
.with_model_file(&args.model.unwrap_or_default()) .with_model_file(&args.model.unwrap_or_default())
.with_task(args.task.as_str().try_into()?) .with_task(args.task.as_str().try_into()?)
@ -139,7 +137,6 @@ fn main() -> Result<()> {
.with_scale(args.scale.as_str().try_into()?) .with_scale(args.scale.as_str().try_into()?)
.with_model_dtype(args.dtype.as_str().try_into()?) .with_model_dtype(args.dtype.as_str().try_into()?)
.with_model_device(args.device.as_str().try_into()?) .with_model_device(args.device.as_str().try_into()?)
// .with_trt_fp16(args.trt_fp16)
.with_model_trt_fp16(args.trt_fp16) .with_model_trt_fp16(args.trt_fp16)
.with_model_ixx( .with_model_ixx(
0, 0,
@ -174,27 +171,21 @@ fn main() -> Result<()> {
.with_topk(args.topk) .with_topk(args.topk)
.retain_classes(&args.retain_classes) .retain_classes(&args.retain_classes)
.exclude_classes(&args.exclude_classes); .exclude_classes(&args.exclude_classes);
if args.use_coco_80_classes { if args.use_coco_80_classes {
config = config.with_class_names(&NAMES_COCO_80); config = config.with_class_names(&NAMES_COCO_80);
} }
if args.use_coco_17_keypoints_classes { if args.use_coco_17_keypoints_classes {
config = config.with_keypoint_names(&NAMES_COCO_KEYPOINTS_17); config = config.with_keypoint_names(&NAMES_COCO_KEYPOINTS_17);
} }
if args.use_imagenet_1k_classes { if args.use_imagenet_1k_classes {
config = config.with_class_names(&NAMES_IMAGENET_1K); config = config.with_class_names(&NAMES_IMAGENET_1K);
} }
if let Some(nc) = args.num_classes { if let Some(nc) = args.num_classes {
config = config.with_nc(nc); config = config.with_nc(nc);
} }
if let Some(nk) = args.num_keypoints { if let Some(nk) = args.num_keypoints {
config = config.with_nk(nk); config = config.with_nk(nk);
} }
if !args.class_names.is_empty() { if !args.class_names.is_empty() {
config = config.with_class_names( config = config.with_class_names(
&args &args
@ -204,7 +195,6 @@ fn main() -> Result<()> {
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
); );
} }
if !args.keypoint_names.is_empty() { if !args.keypoint_names.is_empty() {
config = config.with_keypoint_names( config = config.with_keypoint_names(
&args &args
@ -216,7 +206,7 @@ fn main() -> Result<()> {
} }
// build model // build model
let mut model = YOLO::try_from(config.auto_yolo_model_file().commit()?)?; let mut model = YOLO::new(config.commit()?)?;
// build dataloader // build dataloader
let dl = DataLoader::new(&args.source)? let dl = DataLoader::new(&args.source)?
@ -256,6 +246,7 @@ fn main() -> Result<()> {
} }
} }
// summary
model.summary(); model.summary();
Ok(()) Ok(())

View File

@ -37,9 +37,9 @@ pub struct ModelConfig {
pub processor: ProcessorConfig, pub processor: ProcessorConfig,
// Others // Others
pub class_names: Option<Vec<String>>, // TODO: remove Option pub class_names: Vec<String>,
pub keypoint_names: Option<Vec<String>>, // TODO: remove Option pub keypoint_names: Vec<String>,
pub text_names: Option<Vec<String>>, // TODO: remove Option pub text_names: Vec<String>,
pub class_confs: Vec<f32>, pub class_confs: Vec<f32>,
pub keypoint_confs: Vec<f32>, pub keypoint_confs: Vec<f32>,
pub text_confs: Vec<f32>, pub text_confs: Vec<f32>,
@ -68,9 +68,9 @@ pub struct ModelConfig {
impl Default for ModelConfig { impl Default for ModelConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
class_names: None, class_names: vec![],
keypoint_names: None, keypoint_names: vec![],
text_names: None, text_names: vec![],
class_confs: vec![0.25f32], class_confs: vec![0.25f32],
keypoint_confs: vec![0.3f32], keypoint_confs: vec![0.3f32],
text_confs: vec![0.25f32], text_confs: vec![0.25f32],
@ -130,11 +130,29 @@ impl ModelConfig {
} }
pub fn commit(mut self) -> anyhow::Result<Self> { pub fn commit(mut self) -> anyhow::Result<Self> {
// special case for yolo
if self.name == "yolo" && self.model.file.is_empty() {
// version-scale-task
let mut y = String::new();
if let Some(x) = self.version() {
y.push_str(&x.to_string());
}
if let Some(x) = self.scale() {
y.push_str(&format!("-{}", x));
}
if let Some(x) = self.task() {
y.push_str(&format!("-{}", x.yolo_str()));
}
y.push_str(".onnx");
self.model.file = y;
}
fn try_commit(name: &str, mut m: EngineConfig) -> anyhow::Result<EngineConfig> { fn try_commit(name: &str, mut m: EngineConfig) -> anyhow::Result<EngineConfig> {
if !m.file.is_empty() { if !m.file.is_empty() {
m = m.try_commit(name)?; m = m.try_commit(name)?;
return Ok(m); return Ok(m);
} }
Ok(m) Ok(m)
} }

View File

@ -31,7 +31,7 @@ impl Prob {
.with_confidence(confidence); .with_confidence(confidence);
if let Some(names) = names { if let Some(names) = names {
if id < names.len() { if !names.is_empty() {
meta = meta.with_name(names[id]); meta = meta.with_name(names[id]);
} }
} }

View File

@ -13,6 +13,7 @@ impl crate::ModelConfig {
.with_image_std(&[0.229, 0.224, 0.225]) .with_image_std(&[0.229, 0.224, 0.225])
.with_normalize(true) .with_normalize(true)
.with_apply_softmax(true) .with_apply_softmax(true)
.with_topk(5)
.with_class_names(&NAMES_IMAGENET_1K) .with_class_names(&NAMES_IMAGENET_1K)
} }

View File

@ -11,5 +11,6 @@ impl crate::ModelConfig {
.with_image_std(&[0.5, 0.5, 0.5]) .with_image_std(&[0.5, 0.5, 0.5])
.with_resize_mode(crate::ResizeMode::FitExact) .with_resize_mode(crate::ResizeMode::FitExact)
.with_normalize(true) .with_normalize(true)
.with_model_file("model.onnx")
} }
} }

View File

@ -34,17 +34,14 @@ impl Florence2 {
let encoder = Engine::try_from_config(&config.textual_encoder)?; let encoder = Engine::try_from_config(&config.textual_encoder)?;
let decoder = Engine::try_from_config(&config.textual_decoder)?; let decoder = Engine::try_from_config(&config.textual_decoder)?;
let decoder_merged = Engine::try_from_config(&config.textual_decoder_merged)?; let decoder_merged = Engine::try_from_config(&config.textual_decoder_merged)?;
let (batch, height, width) = ( let (batch, height, width) = (
vision_encoder.batch().opt(), vision_encoder.batch().opt(),
vision_encoder.try_height().unwrap_or(&1024.into()).opt(), vision_encoder.try_height().unwrap_or(&1024.into()).opt(),
vision_encoder.try_width().unwrap_or(&1024.into()).opt(), vision_encoder.try_width().unwrap_or(&1024.into()).opt(),
); );
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _) .with_image_width(width as _)
.with_image_height(height as _); .with_image_height(height as _);
let quantizer = Quantizer::default(); let quantizer = Quantizer::default();
let ts = Ts::merge(&[ let ts = Ts::merge(&[
vision_encoder.ts(), vision_encoder.ts(),

View File

@ -33,24 +33,21 @@ impl GroundingDINO {
engine.try_width().unwrap_or(&1200.into()).opt(), engine.try_width().unwrap_or(&1200.into()).opt(),
engine.ts().clone(), engine.ts().clone(),
); );
let class_names: Vec<_> = config
let class_names = config
.text_names .text_names
.as_ref() .iter()
.and_then(|v| { .map(|s| s.trim().to_ascii_lowercase())
let v: Vec<_> = v .filter(|s| !s.is_empty())
.iter() .collect();
.map(|s| s.trim().to_ascii_lowercase()) if class_names.is_empty() {
.filter(|s| !s.is_empty()) anyhow::bail!(
.collect(); "No valid class names were provided in the config. Ensure the 'text_names' field is non-empty and contains valid class names."
(!v.is_empty()).then_some(v) );
}) }
.ok_or_else(|| anyhow::anyhow!("No valid class names were provided in the config. Ensure the 'text_names' field is non-empty and contains valid class names."))?;
let text_prompt = class_names.iter().fold(String::new(), |mut acc, text| { let text_prompt = class_names.iter().fold(String::new(), |mut acc, text| {
write!(&mut acc, "{}.", text).unwrap(); write!(&mut acc, "{}.", text).unwrap();
acc acc
}); });
let confs_visual = DynConf::new(config.class_confs(), class_names.len()); let confs_visual = DynConf::new(config.class_confs(), class_names.len());
let confs_textual = DynConf::new(config.text_confs(), class_names.len()); let confs_textual = DynConf::new(config.text_confs(), class_names.len());
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?

View File

@ -299,6 +299,7 @@ impl Moondream2 {
cy * image_height as f32, cy * image_height as f32,
)) ))
.with_id(0) .with_id(0)
.with_confidence(1.)
.with_name(object)]); .with_name(object)]);
// keep? // keep?

View File

@ -31,12 +31,12 @@ impl OWLv2 {
engine.ts.clone(), engine.ts.clone(),
); );
let spec = engine.spec().to_owned(); let spec = engine.spec().to_owned();
let names: Vec<String> = config let names: Vec<String> = config.text_names().to_vec();
.class_names() if names.is_empty() {
.expect("No class names specified.") anyhow::bail!(
.iter() "No valid class names were provided in the config. Ensure the 'text_names' field is non-empty and contains valid class names."
.map(|x| x.to_string()) );
.collect(); }
let names_with_prompt: Vec<String> = let names_with_prompt: Vec<String> =
names.iter().map(|x| format!("a photo of {}", x)).collect(); names.iter().map(|x| format!("a photo of {}", x)).collect();
let n = names.len(); let n = names.len();

View File

@ -28,10 +28,7 @@ impl PicoDet {
engine.ts.clone(), engine.ts.clone(),
); );
let spec = engine.spec().to_owned(); let spec = engine.spec().to_owned();
let names = config let names: Vec<String> = config.class_names().to_vec();
.class_names()
.expect("No class names are specified.")
.to_vec();
let confs = DynConf::new(config.class_confs(), names.len()); let confs = DynConf::new(config.class_confs(), names.len());
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _) .with_image_width(width as _)
@ -94,14 +91,15 @@ impl PicoDet {
return None; return None;
} }
let (x1, y1, x2, y2) = (pred[2], pred[3], pred[4], pred[5]); let (x1, y1, x2, y2) = (pred[2], pred[3], pred[4], pred[5]);
let mut hbb = Hbb::default()
.with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2)
.with_confidence(confidence)
.with_id(class_id);
if !self.names.is_empty() {
hbb = hbb.with_name(&self.names[class_id]);
}
Some( Some(hbb)
Hbb::default()
.with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2)
.with_confidence(confidence)
.with_id(class_id)
.with_name(&self.names[class_id]),
)
}) })
.collect(); .collect();

View File

@ -40,7 +40,6 @@ impl BaseModelVisual {
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _) .with_image_width(width as _)
.with_image_height(height as _); .with_image_height(height as _);
let device = config.model.device; let device = config.model.device;
let task = config.task; let task = config.task;
let scale = config.scale; let scale = config.scale;

View File

@ -3,7 +3,7 @@ use anyhow::Result;
use ndarray::Axis; use ndarray::Axis;
use rayon::prelude::*; use rayon::prelude::*;
use crate::{elapsed, DynConf, Engine, Image, ModelConfig, Prob, Processor, Ts, Xs, Y}; use crate::{elapsed, Engine, Image, ModelConfig, Prob, Processor, Ts, Xs, Y};
#[derive(Debug, Builder)] #[derive(Debug, Builder)]
pub struct ImageClassifier { pub struct ImageClassifier {
@ -12,20 +12,24 @@ pub struct ImageClassifier {
width: usize, width: usize,
batch: usize, batch: usize,
apply_softmax: bool, apply_softmax: bool,
ts: Ts,
processor: Processor, processor: Processor,
confs: DynConf,
nc: usize,
names: Vec<String>, names: Vec<String>,
spec: String, spec: String,
topk: usize,
ts: Ts,
} }
impl TryFrom<ModelConfig> for ImageClassifier { impl TryFrom<ModelConfig> for ImageClassifier {
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_from(config: ModelConfig) -> Result<Self, Self::Error> { fn try_from(config: ModelConfig) -> Result<Self, Self::Error> {
let engine = Engine::try_from_config(&config.model)?; Self::new(config)
}
}
impl ImageClassifier {
pub fn new(config: ModelConfig) -> Result<Self> {
let engine = Engine::try_from_config(&config.model)?;
let spec = engine.spec().to_string(); let spec = engine.spec().to_string();
let (batch, height, width, ts) = ( let (batch, height, width, ts) = (
engine.batch().opt(), engine.batch().opt(),
@ -33,29 +37,9 @@ impl TryFrom<ModelConfig> for ImageClassifier {
engine.try_width().unwrap_or(&224.into()).opt(), engine.try_width().unwrap_or(&224.into()).opt(),
engine.ts().clone(), engine.ts().clone(),
); );
let names = config.class_names.to_vec();
let (nc, names) = match (config.nc(), config.class_names()) {
(Some(nc), Some(names)) => {
if nc != names.len() {
anyhow::bail!(
"The length of the input class names: {} is inconsistent with the number of classes: {}.",
names.len(),
nc
);
}
(nc, names.to_vec())
}
(Some(nc), None) => (
nc,
(0..nc).map(|x| format!("# {}", x)).collect::<Vec<String>>(),
),
(None, Some(names)) => (names.len(), names.to_vec()),
(None, None) => {
anyhow::bail!("Neither class names nor class numbers were specified.");
}
};
let confs = DynConf::new(config.class_confs(), nc);
let apply_softmax = config.apply_softmax.unwrap_or_default(); let apply_softmax = config.apply_softmax.unwrap_or_default();
let topk = config.topk.unwrap_or(5);
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _) .with_image_width(width as _)
.with_image_height(height as _); .with_image_height(height as _);
@ -65,18 +49,15 @@ impl TryFrom<ModelConfig> for ImageClassifier {
height, height,
width, width,
batch, batch,
nc,
ts, ts,
spec, spec,
processor, processor,
confs,
names, names,
apply_softmax, apply_softmax,
topk,
}) })
} }
}
impl ImageClassifier {
pub fn summary(&mut self) { pub fn summary(&mut self) {
self.ts.summary(); self.ts.summary();
} }
@ -114,7 +95,7 @@ impl ImageClassifier {
let probs = Prob::new_probs( let probs = Prob::new_probs(
&logits.into_raw_vec_and_offset().0, &logits.into_raw_vec_and_offset().0,
Some(&self.names.iter().map(|x| x.as_str()).collect::<Vec<_>>()), Some(&self.names.iter().map(|x| x.as_str()).collect::<Vec<_>>()),
3, self.topk,
); );
Some(Y::default().with_probs(&probs)) Some(Y::default().with_probs(&probs))

View File

@ -28,12 +28,7 @@ impl RFDETR {
engine.ts.clone(), engine.ts.clone(),
); );
let spec = engine.spec().to_owned(); let spec = engine.spec().to_owned();
let names: Vec<String> = config let names: Vec<String> = config.class_names().to_vec();
.class_names()
.expect("No class names specified.")
.iter()
.map(|x| x.to_string())
.collect();
let confs = DynConf::new(config.class_confs(), names.len()); let confs = DynConf::new(config.class_confs(), names.len());
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _) .with_image_width(width as _)
@ -107,14 +102,15 @@ impl RFDETR {
let y = cy - h / 2.; let y = cy - h / 2.;
let x = x.max(0.0).min(image_width as _); let x = x.max(0.0).min(image_width as _);
let y = y.max(0.0).min(image_height as _); let y = y.max(0.0).min(image_height as _);
let mut hbb = Hbb::default()
.with_xywh(x, y, w, h)
.with_confidence(conf)
.with_id(class_id as _);
if !self.names.is_empty() {
hbb = hbb.with_name(&self.names[class_id]);
}
Some( Some(hbb)
Hbb::default()
.with_xywh(x, y, w, h)
.with_confidence(conf)
.with_id(class_id as _)
.with_name(&self.names[class_id]),
)
}) })
.collect(); .collect();

View File

@ -28,12 +28,7 @@ impl RTDETR {
engine.ts.clone(), engine.ts.clone(),
); );
let spec = engine.spec().to_owned(); let spec = engine.spec().to_owned();
let names: Vec<String> = config let names: Vec<String> = config.class_names().to_vec();
.class_names()
.expect("No class names specified.")
.iter()
.map(|x| x.to_string())
.collect();
let confs = DynConf::new(config.class_confs(), names.len()); let confs = DynConf::new(config.class_confs(), names.len());
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _) .with_image_width(width as _)
@ -94,7 +89,6 @@ impl RTDETR {
if score < self.confs[class_id] { if score < self.confs[class_id] {
continue; continue;
} }
let xyxy = boxes.slice(s![i, ..]); let xyxy = boxes.slice(s![i, ..]);
let (x1, y1, x2, y2) = ( let (x1, y1, x2, y2) = (
xyxy[0] / ratio, xyxy[0] / ratio,
@ -102,13 +96,14 @@ impl RTDETR {
xyxy[2] / ratio, xyxy[2] / ratio,
xyxy[3] / ratio, xyxy[3] / ratio,
); );
y_bboxes.push( let mut hbb = Hbb::default()
Hbb::default() .with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2)
.with_xyxy(x1.max(0.0f32), y1.max(0.0f32), x2, y2) .with_confidence(score)
.with_confidence(score) .with_id(class_id);
.with_id(class_id) if !self.names.is_empty() {
.with_name(&self.names[class_id]), hbb = hbb.with_name(&self.names[class_id]);
); }
y_bboxes.push(hbb);
} }
let mut y = Y::default(); let mut y = Y::default();

View File

@ -27,7 +27,6 @@ impl RTMO {
engine.try_width().unwrap_or(&512.into()).opt(), engine.try_width().unwrap_or(&512.into()).opt(),
engine.ts().clone(), engine.ts().clone(),
); );
let nk = config.nk().unwrap_or(17); let nk = config.nk().unwrap_or(17);
let confs = DynConf::new(config.class_confs(), 1); let confs = DynConf::new(config.class_confs(), 1);
let kconfs = DynConf::new(config.keypoint_confs(), nk); let kconfs = DynConf::new(config.keypoint_confs(), nk);

View File

@ -30,7 +30,6 @@ impl SAM2 {
); );
let ts = Ts::merge(&[encoder.ts(), decoder.ts()]); let ts = Ts::merge(&[encoder.ts(), decoder.ts()]);
let spec = encoder.spec().to_owned(); let spec = encoder.spec().to_owned();
let conf = DynConf::new(config.class_confs(), 1); let conf = DynConf::new(config.class_confs(), 1);
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _) .with_image_width(width as _)

View File

@ -11,7 +11,7 @@ pub struct Sapiens {
width: usize, width: usize,
batch: usize, batch: usize,
task: Task, task: Task,
names_body: Option<Vec<String>>, names_body: Vec<String>,
ts: Ts, ts: Ts,
processor: Processor, processor: Processor,
spec: String, spec: String,
@ -27,7 +27,6 @@ impl Sapiens {
engine.try_width().unwrap_or(&768.into()).opt(), engine.try_width().unwrap_or(&768.into()).opt(),
engine.ts().clone(), engine.ts().clone(),
); );
let task = config.task.expect("No sapiens task specified."); let task = config.task.expect("No sapiens task specified.");
let names_body = config.class_names; let names_body = config.class_names;
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?
@ -124,8 +123,8 @@ impl Sapiens {
if let Some(polygon) = mask.polygon() { if let Some(polygon) = mask.polygon() {
y_polygons.push(polygon); y_polygons.push(polygon);
} }
if let Some(names_body) = &self.names_body { if !self.names_body.is_empty() {
mask = mask.with_name(&names_body[*i]); mask = mask.with_name(&self.names_body[*i]);
} }
y_masks.push(mask); y_masks.push(mask);
} }

View File

@ -15,7 +15,7 @@ impl crate::ModelConfig {
.with_scale(crate::Scale::Million(256.)) .with_scale(crate::Scale::Million(256.))
.with_visual_file("256m-vision-encoder.onnx") .with_visual_file("256m-vision-encoder.onnx")
.with_textual_file("256m-embed-tokens.onnx") .with_textual_file("256m-embed-tokens.onnx")
.with_textual_decoder_file("256m-decoder-model-merged.onnx") .with_textual_decoder_merged_file("256m-decoder-model-merged.onnx")
} }
pub fn smolvlm_500m() -> Self { pub fn smolvlm_500m() -> Self {
@ -23,6 +23,6 @@ impl crate::ModelConfig {
.with_scale(crate::Scale::Million(500.)) .with_scale(crate::Scale::Million(500.))
.with_visual_file("500m-vision-encoder.onnx") .with_visual_file("500m-vision-encoder.onnx")
.with_textual_file("500m-embed-tokens.onnx") .with_textual_file("500m-embed-tokens.onnx")
.with_textual_decoder_file("500m-decoder-model-merged.onnx") .with_textual_decoder_merged_file("500m-decoder-model-merged.onnx")
} }
} }

View File

@ -36,7 +36,6 @@ impl SmolVLM {
let vision = Engine::try_from_config(&config.visual)?; let vision = Engine::try_from_config(&config.visual)?;
let text_embed = Engine::try_from_config(&config.textual)?; let text_embed = Engine::try_from_config(&config.textual)?;
let decoder = Engine::try_from_config(&config.textual_decoder_merged)?; let decoder = Engine::try_from_config(&config.textual_decoder_merged)?;
let fake_image_token = "<fake_token_around_image>".to_string(); let fake_image_token = "<fake_token_around_image>".to_string();
let image_token = "<image>".to_string(); let image_token = "<image>".to_string();
let global_img_token = "<global-img>".to_string(); let global_img_token = "<global-img>".to_string();
@ -52,7 +51,6 @@ impl SmolVLM {
_ => unimplemented!(), _ => unimplemented!(),
}; };
let scale = config.scale.clone().unwrap(); let scale = config.scale.clone().unwrap();
let (batch, num_patch, height, width, ts) = ( let (batch, num_patch, height, width, ts) = (
vision.batch().opt(), vision.batch().opt(),
vision.inputs_minoptmax()[0][1].opt(), vision.inputs_minoptmax()[0][1].opt(),

View File

@ -65,7 +65,6 @@ impl TrOCR {
Some(Scale::B) => 12, Some(Scale::B) => 12,
_ => unimplemented!(), _ => unimplemented!(),
}; };
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?
.with_image_width(width as _) .with_image_width(width as _)
.with_image_height(height as _); .with_image_height(height as _);

View File

@ -1,6 +1,7 @@
use crate::{ use crate::{
models::YOLOPredsFormat, ModelConfig, ResizeMode, Scale, Task, NAMES_COCO_80, models::YOLOPredsFormat, ModelConfig, ResizeMode, Scale, Task, NAMES_COCO_80,
NAMES_COCO_KEYPOINTS_17, NAMES_IMAGENET_1K, NAMES_YOLO_DOCLAYOUT_10, NAMES_COCO_KEYPOINTS_17, NAMES_DOTA_V1_15, NAMES_IMAGENET_1K, NAMES_YOLOE_4585,
NAMES_YOLO_DOCLAYOUT_10,
}; };
impl ModelConfig { impl ModelConfig {
@ -13,7 +14,6 @@ impl ModelConfig {
.with_model_ixx(0, 3, 640.into()) .with_model_ixx(0, 3, 640.into())
.with_resize_mode(ResizeMode::FitAdaptive) .with_resize_mode(ResizeMode::FitAdaptive)
.with_resize_filter("CatmullRom") .with_resize_filter("CatmullRom")
.with_class_names(&NAMES_COCO_80)
} }
pub fn yolo_classify() -> Self { pub fn yolo_classify() -> Self {
@ -27,7 +27,9 @@ impl ModelConfig {
} }
pub fn yolo_detect() -> Self { pub fn yolo_detect() -> Self {
Self::yolo().with_task(Task::ObjectDetection) Self::yolo()
.with_task(Task::ObjectDetection)
.with_class_names(&NAMES_COCO_80)
} }
pub fn yolo_pose() -> Self { pub fn yolo_pose() -> Self {
@ -37,31 +39,17 @@ impl ModelConfig {
} }
pub fn yolo_segment() -> Self { pub fn yolo_segment() -> Self {
Self::yolo().with_task(Task::InstanceSegmentation) Self::yolo()
.with_task(Task::InstanceSegmentation)
.with_class_names(&NAMES_COCO_80)
} }
pub fn yolo_obb() -> Self { pub fn yolo_obb() -> Self {
Self::yolo().with_task(Task::OrientedObjectDetection) Self::yolo()
} .with_model_ixx(0, 2, 1024.into())
.with_model_ixx(0, 3, 1024.into())
pub fn auto_yolo_model_file(mut self) -> Self { .with_task(Task::OrientedObjectDetection)
if self.model.file.is_empty() { .with_class_names(&NAMES_DOTA_V1_15)
// [version]-[scale]-[task]
let mut y = String::new();
if let Some(x) = self.version() {
y.push_str(&x.to_string());
}
if let Some(x) = self.scale() {
y.push_str(&format!("-{}", x));
}
if let Some(x) = self.task() {
y.push_str(&format!("-{}", x.yolo_str()));
}
y.push_str(".onnx");
self.model.file = y;
}
self
} }
pub fn doclayout_yolo_docstructbench() -> Self { pub fn doclayout_yolo_docstructbench() -> Self {
@ -75,60 +63,80 @@ impl ModelConfig {
} }
// YOLOE models // YOLOE models
pub fn yoloe() -> Self {
Self::yolo()
.with_task(Task::InstanceSegmentation)
.with_class_names(&NAMES_YOLOE_4585)
}
pub fn yoloe_v8s_seg_pf() -> Self { pub fn yoloe_v8s_seg_pf() -> Self {
Self::yolo_segment() Self::yoloe()
.with_version(8.into()) .with_version(8.into())
.with_scale(Scale::S) .with_scale(Scale::S)
.with_model_file("yoloe-v8s-seg-pf.onnx") .with_model_file("yoloe-v8s-seg-pf.onnx")
} }
pub fn yoloe_v8m_seg_pf() -> Self { pub fn yoloe_v8m_seg_pf() -> Self {
Self::yolo_segment() Self::yoloe()
.with_version(8.into()) .with_version(8.into())
.with_scale(Scale::M) .with_scale(Scale::M)
.with_model_file("yoloe-v8m-seg-pf.onnx") .with_model_file("yoloe-v8m-seg-pf.onnx")
} }
pub fn yoloe_v8l_seg_pf() -> Self { pub fn yoloe_v8l_seg_pf() -> Self {
Self::yolo_segment() Self::yoloe()
.with_version(8.into()) .with_version(8.into())
.with_scale(Scale::L) .with_scale(Scale::L)
.with_model_file("yoloe-v8l-seg-pf.onnx") .with_model_file("yoloe-v8l-seg-pf.onnx")
} }
pub fn yoloe_11s_seg_pf() -> Self { pub fn yoloe_11s_seg_pf() -> Self {
Self::yolo_segment() Self::yoloe()
.with_version(11.into()) .with_version(11.into())
.with_scale(Scale::S) .with_scale(Scale::S)
.with_model_file("yoloe-11s-seg-pf.onnx") .with_model_file("yoloe-11s-seg-pf.onnx")
} }
pub fn yoloe_11m_seg_pf() -> Self { pub fn yoloe_11m_seg_pf() -> Self {
Self::yolo_segment() Self::yoloe()
.with_version(11.into()) .with_version(11.into())
.with_scale(Scale::M) .with_scale(Scale::M)
.with_model_file("yoloe-v8m-seg-pf.onnx") .with_model_file("yoloe-v8m-seg-pf.onnx")
} }
pub fn yoloe_11l_seg_pf() -> Self { pub fn yoloe_11l_seg_pf() -> Self {
Self::yolo_segment() Self::yoloe()
.with_version(11.into()) .with_version(11.into())
.with_scale(Scale::L) .with_scale(Scale::L)
.with_model_file("yoloe-11l-seg-pf.onnx") .with_model_file("yoloe-11l-seg-pf.onnx")
} }
/// ---- TODO
pub fn fastsam_s() -> Self { pub fn fastsam_s() -> Self {
Self::yolo_segment() Self::yolo_segment()
.with_class_names(&["object"])
.with_scale(Scale::S) .with_scale(Scale::S)
.with_version(8.into()) .with_version(8.into())
.with_model_file("FastSAM-s.onnx") .with_model_file("FastSAM-s.onnx")
} }
pub fn yolo_v8_rtdetr_l() -> Self { pub fn fastsam_x() -> Self {
Self::yolo_segment()
.with_class_names(&["object"])
.with_scale(Scale::X)
.with_version(8.into())
.with_model_file("FastSAM-x.onnx")
}
pub fn ultralytics_rtdetr_l() -> Self {
Self::yolo_detect() Self::yolo_detect()
.with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n()) .with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n())
.with_scale(Scale::L) .with_scale(Scale::L)
.with_model_file("rtdetr-l-det.onnx") .with_model_file("rtdetr-l.onnx")
}
pub fn ultralytics_rtdetr_x() -> Self {
Self::yolo_detect()
.with_yolo_preds_format(YOLOPredsFormat::n_a_cxcywh_clss_n())
.with_scale(Scale::X)
.with_model_file("rtdetr-x.onnx")
} }
} }

View File

@ -28,12 +28,12 @@ pub struct YOLO {
confs: DynConf, confs: DynConf,
kconfs: DynConf, kconfs: DynConf,
iou: f32, iou: f32,
topk: usize,
processor: Processor, processor: Processor,
ts: Ts, ts: Ts,
spec: String, spec: String,
classes_excluded: Vec<usize>, classes_excluded: Vec<usize>,
classes_retained: Vec<usize>, classes_retained: Vec<usize>,
topk: usize,
} }
impl TryFrom<ModelConfig> for YOLO { impl TryFrom<ModelConfig> for YOLO {
@ -47,7 +47,6 @@ impl TryFrom<ModelConfig> for YOLO {
impl YOLO { impl YOLO {
pub fn new(config: ModelConfig) -> Result<Self> { pub fn new(config: ModelConfig) -> Result<Self> {
let engine = Engine::try_from_config(&config.model)?; let engine = Engine::try_from_config(&config.model)?;
let (batch, height, width, ts, spec) = ( let (batch, height, width, ts, spec) = (
engine.batch().opt(), engine.batch().opt(),
engine.try_height().unwrap_or(&640.into()).opt(), engine.try_height().unwrap_or(&640.into()).opt(),
@ -55,7 +54,6 @@ impl YOLO {
engine.ts.clone(), engine.ts.clone(),
engine.spec().to_owned(), engine.spec().to_owned(),
); );
let task: Option<Task> = match &config.task { let task: Option<Task> = match &config.task {
Some(task) => Some(task.clone()), Some(task) => Some(task.clone()),
None => match engine.try_fetch("task") { None => match engine.try_fetch("task") {
@ -128,9 +126,10 @@ impl YOLO {
(Task::InstanceSegmentation, Version(5, 0, _)) => { (Task::InstanceSegmentation, Version(5, 0, _)) => {
YOLOPredsFormat::n_a_cxcywh_confclss_coefs() YOLOPredsFormat::n_a_cxcywh_confclss_coefs()
} }
(Task::InstanceSegmentation, Version(8, 0, _) | Version(11, 0, _)) => { (
YOLOPredsFormat::n_cxcywh_clss_coefs_a() Task::InstanceSegmentation,
} Version(8, 0, _) | Version(9, 0, _) | Version(11, 0, _),
) => YOLOPredsFormat::n_cxcywh_clss_coefs_a(),
(Task::OrientedObjectDetection, Version(8, 0, _) | Version(11, 0, _)) => { (Task::OrientedObjectDetection, Version(8, 0, _) | Version(11, 0, _)) => {
YOLOPredsFormat::n_cxcywh_clss_r_a() YOLOPredsFormat::n_cxcywh_clss_r_a()
} }
@ -169,61 +168,69 @@ impl YOLO {
}; };
// Class names // Class names
let names: Option<Vec<String>> = match Self::fetch_names_from_onnx(&engine) { let names_parsed = Self::fetch_names_from_onnx(&engine);
Some(names_parsed) => match &config.class_names { let names_customized = config.class_names.to_vec();
Some(names) => { let names: Vec<_> = match (names_parsed, names_customized.is_empty()) {
if names.len() == names_parsed.len() { (None, true) => vec![],
// prioritize user-defined (None, false) => names_customized,
Some(names.clone()) (Some(names_parsed), true) => names_parsed,
} else { (Some(names_parsed), false) => {
// Fail to override if names_parsed.len() == names_customized.len() {
anyhow::bail!( names_customized // prioritize user-defined
"The lengths of parsed class names: {} and user-defined class names: {} do not match.", } else {
names_parsed.len(), anyhow::bail!(
names.len(), "The lengths of parsed class names: {} and user-defined class names: {} do not match.",
) names_parsed.len(),
} names_customized.len(),
);
} }
None => Some(names_parsed),
},
None => config.class_names.clone(),
};
// Class names & Number of class
let (nc, names) = match (config.nc(), names) {
(_, Some(names)) => (names.len(), names.to_vec()),
(Some(nc), None) => (nc, Self::n2s(nc)),
(None, None) => {
anyhow::bail!(
"Neither class names nor the number of classes were specified. \
\nConsider specify them with `ModelConfig::default().with_nc()` or `ModelConfig::default().with_class_names()`"
);
} }
}; };
// Class names & Number of class
let nc = match config.nc() {
None => names.len(),
Some(n) => {
if names.len() != n {
anyhow::bail!(
"The lengths of class names: {} and user-defined num_classes: {} do not match.",
names.len(),
n,
)
}
n
}
};
if nc == 0 && names.is_empty() {
anyhow::bail!(
"Neither class names nor the number of classes were specified. \
\nConsider specify them with `ModelConfig::default().with_nc()` or `ModelConfig::default().with_class_names()`"
);
}
// Keypoint names & Number of keypoints // Keypoint names & Number of keypoints
let (nk, names_kpt) = if let Task::KeypointsDetection = task { let names_kpt = config.keypoint_names.to_vec();
let nk = Self::fetch_nk_from_onnx(&engine).or(config.nk()); let nk = if let Task::KeypointsDetection = task {
match (&config.keypoint_names, nk) { match (names_kpt.is_empty(), Self::fetch_nk_from_onnx(&engine).or(config.nk())) {
(Some(names), Some(nk)) => { (false, Some(nk)) => {
if names.len() != nk { if names_kpt.len() != nk {
anyhow::bail!( anyhow::bail!(
"The lengths of user-defined keypoint names: {} and nk parsed: {} do not match.", "The lengths of user-defined keypoint class names: {} and num_keypoints: {} do not match.",
names.len(), names_kpt.len(),
nk, nk,
); );
} }
(nk, names.clone()) nk
} },
(Some(names), None) => (names.len(), names.clone()), (false, None) => names_kpt.len(),
(None, Some(nk)) => (nk, Self::n2s(nk)), (true, Some(nk)) => nk,
(None, None) => anyhow::bail!( (true, None) => anyhow::bail!(
"Neither keypoint names nor the number of keypoints were specified when doing `KeypointsDetection` task. \ "Neither keypoint names nor the number of keypoints were specified when doing `KeypointsDetection` task. \
\nConsider specify them with `ModelConfig::default().with_nk()` or `ModelConfig::default().with_keypoint_names()`" \nConsider specify them with `ModelConfig::default().with_nk()` or `ModelConfig::default().with_keypoint_names()`"
), ),
} }
} else { } else {
(0, vec![]) 0
}; };
// Attributes // Attributes
@ -295,12 +302,8 @@ impl YOLO {
Ok(ys) Ok(ys)
} }
pub fn summary(&mut self) {
self.ts.summary();
}
fn postprocess(&self, xs: Xs) -> Result<Vec<Y>> { fn postprocess(&self, xs: Xs) -> Result<Vec<Y>> {
let protos = if xs.len() == 2 { Some(&xs[1]) } else { None }; // let protos = if xs.len() == 2 { Some(&xs[1]) } else { None };
let ys: Vec<Y> = xs[0] let ys: Vec<Y> = xs[0]
.axis_iter(Axis(0)) .axis_iter(Axis(0))
.into_par_iter() .into_par_iter()
@ -345,11 +348,11 @@ impl YOLO {
let ratio = self.processor.images_transform_info[idx].height_scale; let ratio = self.processor.images_transform_info[idx].height_scale;
// Other tasks // Other tasks
let (y_bboxes, y_mbrs) = slice_bboxes? let (y_hbbs, y_obbs) = slice_bboxes?
.axis_iter(Axis(0)) .axis_iter(Axis(0))
.into_par_iter() .into_par_iter()
.enumerate() .enumerate()
.filter_map(|(i, bbox)| { .filter_map(|(i, hbb)| {
// confidence & class_id // confidence & class_id
let (class_id, confidence) = match &slice_id { let (class_id, confidence) = match &slice_id {
Some(ids) => (ids[[i, 0]] as _, slice_clss[[i, 0]] as _), Some(ids) => (ids[[i, 0]] as _, slice_clss[[i, 0]] as _),
@ -389,50 +392,50 @@ impl YOLO {
} }
// Bboxes // Bboxes
let bbox = bbox.mapv(|x| x / ratio); let hbb = hbb.mapv(|x| x / ratio);
let bbox = if self.layout.is_bbox_normalized { let hbb = if self.layout.is_bbox_normalized {
( (
bbox[0] * self.width() as f32, hbb[0] * self.width() as f32,
bbox[1] * self.height() as f32, hbb[1] * self.height() as f32,
bbox[2] * self.width() as f32, hbb[2] * self.width() as f32,
bbox[3] * self.height() as f32, hbb[3] * self.height() as f32,
) )
} else { } else {
(bbox[0], bbox[1], bbox[2], bbox[3]) (hbb[0], hbb[1], hbb[2], hbb[3])
}; };
let (cx, cy, x, y, w, h) = match self.layout.box_type()? { let (cx, cy, x, y, w, h) = match self.layout.box_type()? {
BoxType::Cxcywh => { BoxType::Cxcywh => {
let (cx, cy, w, h) = bbox; let (cx, cy, w, h) = hbb;
let x = (cx - w / 2.).max(0.); let x = (cx - w / 2.).max(0.);
let y = (cy - h / 2.).max(0.); let y = (cy - h / 2.).max(0.);
(cx, cy, x, y, w, h) (cx, cy, x, y, w, h)
} }
BoxType::Xyxy => { BoxType::Xyxy => {
let (x, y, x2, y2) = bbox; let (x, y, x2, y2) = hbb;
let (w, h) = (x2 - x, y2 - y); let (w, h) = (x2 - x, y2 - y);
let (cx, cy) = ((x + x2) / 2., (y + y2) / 2.); let (cx, cy) = ((x + x2) / 2., (y + y2) / 2.);
(cx, cy, x, y, w, h) (cx, cy, x, y, w, h)
} }
BoxType::Xywh => { BoxType::Xywh => {
let (x, y, w, h) = bbox; let (x, y, w, h) = hbb;
let (cx, cy) = (x + w / 2., y + h / 2.); let (cx, cy) = (x + w / 2., y + h / 2.);
(cx, cy, x, y, w, h) (cx, cy, x, y, w, h)
} }
BoxType::Cxcyxy => { BoxType::Cxcyxy => {
let (cx, cy, x2, y2) = bbox; let (cx, cy, x2, y2) = hbb;
let (w, h) = ((x2 - cx) * 2., (y2 - cy) * 2.); let (w, h) = ((x2 - cx) * 2., (y2 - cy) * 2.);
let x = (x2 - w).max(0.); let x = (x2 - w).max(0.);
let y = (y2 - h).max(0.); let y = (y2 - h).max(0.);
(cx, cy, x, y, w, h) (cx, cy, x, y, w, h)
} }
BoxType::XyCxcy => { BoxType::XyCxcy => {
let (x, y, cx, cy) = bbox; let (x, y, cx, cy) = hbb;
let (w, h) = ((cx - x) * 2., (cy - y) * 2.); let (w, h) = ((cx - x) * 2., (cy - y) * 2.);
(cx, cy, x, y, w, h) (cx, cy, x, y, w, h)
} }
}; };
let (y_bbox, y_mbr) = match &slice_radians { let (y_hbb, y_obb) = match &slice_radians {
Some(slice_radians) => { Some(slice_radians) => {
let radians = slice_radians[[i, 0]]; let radians = slice_radians[[i, 0]];
let (w, h, radians) = if w > h { let (w, h, radians) = if w > h {
@ -441,47 +444,51 @@ impl YOLO {
(h, w, radians + std::f32::consts::PI / 2.) (h, w, radians + std::f32::consts::PI / 2.)
}; };
let radians = radians % std::f32::consts::PI; let radians = radians % std::f32::consts::PI;
let mbr = Obb::from_cxcywhr(cx, cy, w, h, radians) let mut obb = Obb::from_cxcywhr(cx, cy, w, h, radians)
.with_confidence(confidence) .with_confidence(confidence)
.with_id(class_id) .with_id(class_id);
.with_name(&self.names[class_id]); if !self.names.is_empty() {
obb = obb.with_name(&self.names[class_id]);
}
(None, Some(mbr)) (None, Some(obb))
} }
None => { None => {
let bbox = Hbb::default() let mut hbb = Hbb::default()
.with_xywh(x, y, w, h) .with_xywh(x, y, w, h)
.with_confidence(confidence) .with_confidence(confidence)
.with_id(class_id) .with_id(class_id)
.with_uid(i) .with_uid(i);
.with_name(&self.names[class_id]); if !self.names.is_empty() {
hbb = hbb.with_name(&self.names[class_id]);
}
(Some(bbox), None) (Some(hbb), None)
} }
}; };
Some((y_bbox, y_mbr)) Some((y_hbb, y_obb))
}) })
.collect::<(Vec<_>, Vec<_>)>(); .collect::<(Vec<_>, Vec<_>)>();
let mut y_bboxes: Vec<Hbb> = y_bboxes.into_iter().flatten().collect(); let mut y_hbbs: Vec<Hbb> = y_hbbs.into_iter().flatten().collect();
let mut y_mbrs: Vec<Obb> = y_mbrs.into_iter().flatten().collect(); let mut y_obbs: Vec<Obb> = y_obbs.into_iter().flatten().collect();
// Mbrs // Mbrs
if !y_mbrs.is_empty() { if !y_obbs.is_empty() {
if self.layout.apply_nms { if self.layout.apply_nms {
y_mbrs.apply_nms_inplace(self.iou); y_obbs.apply_nms_inplace(self.iou);
} }
y = y.with_obbs(&y_mbrs); y = y.with_obbs(&y_obbs);
return Some(y); return Some(y);
} }
// Bboxes // Bboxes
if !y_bboxes.is_empty() { if !y_hbbs.is_empty() {
if self.layout.apply_nms { if self.layout.apply_nms {
y_bboxes.apply_nms_inplace(self.iou); y_hbbs.apply_nms_inplace(self.iou);
} }
y = y.with_hbbs(&y_bboxes); y = y.with_hbbs(&y_hbbs);
} }
// KeypointsDetection // KeypointsDetection
@ -490,8 +497,8 @@ impl YOLO {
if let Some(hbbs) = y.hbbs() { if let Some(hbbs) = y.hbbs() {
let y_kpts = hbbs let y_kpts = hbbs
.into_par_iter() .into_par_iter()
.filter_map(|bbox| { .filter_map(|hbb| {
let pred = pred_kpts.slice(s![bbox.uid(), ..]); let pred = pred_kpts.slice(s![hbb.uid(), ..]);
let kpts = (0..self.nk) let kpts = (0..self.nk)
.into_par_iter() .into_par_iter()
.map(|i| { .map(|i| {
@ -501,14 +508,17 @@ impl YOLO {
if kconf < self.kconfs[i] { if kconf < self.kconfs[i] {
Keypoint::default() Keypoint::default()
} else { } else {
Keypoint::default() let mut kpt = Keypoint::default()
.with_id(i) .with_id(i)
.with_confidence(kconf) .with_confidence(kconf)
.with_xy( .with_xy(
kx.max(0.0f32).min(image_width as f32), kx.max(0.0f32).min(image_width as f32),
ky.max(0.0f32).min(image_height as f32), ky.max(0.0f32).min(image_height as f32),
) );
.with_name(&self.names_kpt[i]) if !self.names_kpt.is_empty() {
kpt = kpt.with_name(&self.names_kpt[i]);
}
kpt
} }
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -522,11 +532,12 @@ impl YOLO {
// InstanceSegmentation // InstanceSegmentation
if let Some(coefs) = slice_coefs { if let Some(coefs) = slice_coefs {
if let Some(hbbs) = y.hbbs() { if let Some(hbbs) = y.hbbs() {
let protos = &xs[1];
let y_masks = hbbs let y_masks = hbbs
.into_par_iter() .into_par_iter()
.filter_map(|bbox| { .filter_map(|hbb| {
let coefs = coefs.slice(s![bbox.uid(), ..]).to_vec(); let coefs = coefs.slice(s![hbb.uid(), ..]).to_vec();
let proto = protos.as_ref()?.slice(s![idx, .., .., ..]); let proto = protos.slice(s![idx, .., .., ..]);
let (nm, mh, mw) = proto.dim(); let (nm, mh, mw) = proto.dim();
// coefs * proto => mask // coefs * proto => mask
@ -553,9 +564,9 @@ impl YOLO {
mask, mask,
)?; )?;
let (xmin, ymin, xmax, ymax) = let (xmin, ymin, xmax, ymax) =
(bbox.xmin(), bbox.ymin(), bbox.xmax(), bbox.ymax()); (hbb.xmin(), hbb.ymin(), hbb.xmax(), hbb.ymax());
// Using bbox to crop the mask // Using hbb to crop the mask
for (y, row) in mask.enumerate_rows_mut() { for (y, row) in mask.enumerate_rows_mut() {
for (x, _, pixel) in row { for (x, _, pixel) in row {
if x < xmin as _ if x < xmin as _
@ -569,10 +580,10 @@ impl YOLO {
} }
let mut mask = Mask::default().with_mask(mask); let mut mask = Mask::default().with_mask(mask);
if let Some(id) = bbox.id() { if let Some(id) = hbb.id() {
mask = mask.with_id(id); mask = mask.with_id(id);
} }
if let Some(name) = bbox.name() { if let Some(name) = hbb.name() {
mask = mask.with_name(name); mask = mask.with_name(name);
} }
@ -613,7 +624,7 @@ impl YOLO {
.and_then(|m| m.as_str().parse::<usize>().ok()) .and_then(|m| m.as_str().parse::<usize>().ok())
} }
fn n2s(n: usize) -> Vec<String> { pub fn summary(&mut self) {
(0..n).map(|x| format!("# {}", x)).collect::<Vec<String>>() self.ts.summary();
} }
} }

View File

@ -29,7 +29,6 @@ impl YOLOPv2 {
engine.try_width().unwrap_or(&512.into()).opt(), engine.try_width().unwrap_or(&512.into()).opt(),
engine.ts().clone(), engine.ts().clone(),
); );
let confs = DynConf::new(config.class_confs(), 80); let confs = DynConf::new(config.class_confs(), 80);
let iou = config.iou.unwrap_or(0.45f32); let iou = config.iou.unwrap_or(0.45f32);
let processor = Processor::try_from_config(&config.processor)? let processor = Processor::try_from_config(&config.processor)?

View File

@ -6,7 +6,13 @@ pub struct DynConf(Vec<f32>);
impl Default for DynConf { impl Default for DynConf {
fn default() -> Self { fn default() -> Self {
Self(vec![0.4f32]) Self(vec![0.3f32])
}
}
impl From<f32> for DynConf {
fn from(conf: f32) -> Self {
Self(vec![conf])
} }
} }

File diff suppressed because it is too large Load Diff