mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Add YOLOv11
* Add YOLOv11
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "usls"
|
||||
version = "0.0.16"
|
||||
version = "0.0.17"
|
||||
edition = "2021"
|
||||
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
|
||||
repository = "https://github.com/jamjamjon/usls"
|
||||
|
@ -24,45 +24,41 @@
|
||||
## Quick Start
|
||||
```Shell
|
||||
|
||||
# customized
|
||||
cargo run -r --example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8
|
||||
|
||||
# Classify
|
||||
cargo run -r --example yolo -- --task classify --ver v5 # YOLOv5
|
||||
cargo run -r --example yolo -- --task classify --ver v8 # YOLOv8
|
||||
cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5
|
||||
cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8
|
||||
cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11
|
||||
|
||||
# Detect
|
||||
cargo run -r --example yolo -- --task detect --ver v5 # YOLOv5
|
||||
cargo run -r --example yolo -- --task detect --ver v6 # YOLOv6
|
||||
cargo run -r --example yolo -- --task detect --ver v7 # YOLOv7
|
||||
cargo run -r --example yolo -- --task detect --ver v8 # YOLOv8
|
||||
cargo run -r --example yolo -- --task detect --ver v9 # YOLOv9
|
||||
cargo run -r --example yolo -- --task detect --ver v10 # YOLOv10
|
||||
cargo run -r --example yolo -- --task detect --ver rtdetr # YOLOv8-RTDETR
|
||||
cargo run -r --example yolo -- --task detect --ver v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world
|
||||
cargo run -r --example yolo -- --task detect --ver v5 --scale n # YOLOv5
|
||||
cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6
|
||||
cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7
|
||||
cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8
|
||||
cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9
|
||||
cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10
|
||||
cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11
|
||||
cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR
|
||||
cargo run -r --example yolo -- --task detect --ver v8 --nc 1 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world <local file>
|
||||
|
||||
# Pose
|
||||
cargo run -r --example yolo -- --task pose --ver v8 # YOLOv8-Pose
|
||||
cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose
|
||||
cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose
|
||||
|
||||
# Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v5 # YOLOv5-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v8 # YOLOv8-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM
|
||||
cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v11 --scale n # YOLOv8-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM <local file>
|
||||
|
||||
# Obb
|
||||
cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb
|
||||
cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb
|
||||
cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb
|
||||
```
|
||||
|
||||
<details close>
|
||||
<summary>other options</summary>
|
||||
|
||||
`--source` to specify the input images
|
||||
`--model` to specify the ONNX model
|
||||
`--width --height` to specify the input resolution
|
||||
`--nc` to specify the number of model's classes
|
||||
`--plot` to annotate with inference results
|
||||
`--profile` to profile
|
||||
`--cuda --trt --coreml --device_id` to select device
|
||||
`--half` to use float16 when using TensorRT EP
|
||||
|
||||
</details>
|
||||
**`cargo run -r --example yolo -- --help` for more options**
|
||||
|
||||
|
||||
## YOLOs configs with `Options`
|
||||
@ -74,6 +70,8 @@ cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb
|
||||
let options = Options::default()
|
||||
.with_yolo_version(YOLOVersion::V5) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
|
||||
.with_yolo_task(YOLOTask::Classify) // YOLOTask: Classify, Detect, Pose, Segment, Obb
|
||||
// .with_nc(80)
|
||||
// .with_names(&COCO_CLASS_NAMES_80)
|
||||
.with_model("xxxx.onnx")?;
|
||||
|
||||
```
|
||||
@ -140,7 +138,7 @@ let options = Options::default()
|
||||
</details>
|
||||
|
||||
<details close>
|
||||
<summary>YOLOv8</summary>
|
||||
<summary>YOLOv8, YOLOv11</summary>
|
||||
|
||||
```Shell
|
||||
pip install -U ultralytics
|
||||
|
@ -2,188 +2,160 @@ use anyhow::Result;
|
||||
use clap::Parser;
|
||||
|
||||
use usls::{
|
||||
models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17,
|
||||
COCO_SKELETONS_16,
|
||||
models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask,
|
||||
YOLOVersion, COCO_SKELETONS_16,
|
||||
};
|
||||
|
||||
#[derive(Parser, Clone)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// Path to the model
|
||||
#[arg(long)]
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Input source path
|
||||
#[arg(long, default_value_t = String::from("./assets/bus.jpg"))]
|
||||
pub source: String,
|
||||
|
||||
/// YOLO Task
|
||||
#[arg(long, value_enum, default_value_t = YOLOTask::Detect)]
|
||||
pub task: YOLOTask,
|
||||
|
||||
/// YOLO Version
|
||||
#[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
|
||||
pub ver: YOLOVersion,
|
||||
|
||||
/// YOLO Scale
|
||||
#[arg(long, value_enum, default_value_t = YOLOScale::N)]
|
||||
pub scale: YOLOScale,
|
||||
|
||||
/// Batch size
|
||||
#[arg(long, default_value_t = 1)]
|
||||
pub batch_size: usize,
|
||||
|
||||
/// Minimum input width
|
||||
#[arg(long, default_value_t = 224)]
|
||||
pub width_min: isize,
|
||||
|
||||
/// Input width
|
||||
#[arg(long, default_value_t = 640)]
|
||||
pub width: isize,
|
||||
|
||||
#[arg(long, default_value_t = 800)]
|
||||
/// Maximum input width
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
pub width_max: isize,
|
||||
|
||||
/// Minimum input height
|
||||
#[arg(long, default_value_t = 224)]
|
||||
pub height_min: isize,
|
||||
|
||||
/// Input height
|
||||
#[arg(long, default_value_t = 640)]
|
||||
pub height: isize,
|
||||
|
||||
#[arg(long, default_value_t = 800)]
|
||||
/// Maximum input height
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
pub height_max: isize,
|
||||
|
||||
/// Number of classes
|
||||
#[arg(long, default_value_t = 80)]
|
||||
pub nc: usize,
|
||||
|
||||
/// Class confidence
|
||||
#[arg(long)]
|
||||
pub confs: Vec<f32>,
|
||||
|
||||
/// Enable TensorRT support
|
||||
#[arg(long)]
|
||||
pub trt: bool,
|
||||
|
||||
/// Enable CUDA support
|
||||
#[arg(long)]
|
||||
pub cuda: bool,
|
||||
|
||||
#[arg(long)]
|
||||
pub half: bool,
|
||||
|
||||
/// Enable CoreML support
|
||||
#[arg(long)]
|
||||
pub coreml: bool,
|
||||
|
||||
/// Use TensorRT half precision
|
||||
#[arg(long)]
|
||||
pub half: bool,
|
||||
|
||||
/// Device ID to use
|
||||
#[arg(long, default_value_t = 0)]
|
||||
pub device_id: usize,
|
||||
|
||||
/// Enable performance profiling
|
||||
#[arg(long)]
|
||||
pub profile: bool,
|
||||
|
||||
#[arg(long)]
|
||||
pub no_plot: bool,
|
||||
|
||||
/// Disable contour drawing
|
||||
#[arg(long)]
|
||||
pub no_contours: bool,
|
||||
|
||||
/// Show result
|
||||
#[arg(long)]
|
||||
pub view: bool,
|
||||
|
||||
/// Do not save output
|
||||
#[arg(long)]
|
||||
pub nosave: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// build options
|
||||
let options = Options::default();
|
||||
// path
|
||||
let path = args.model.unwrap_or({
|
||||
format!(
|
||||
"yolo/{}-{}-{}.onnx",
|
||||
args.ver.name(),
|
||||
args.scale.name(),
|
||||
args.task.name()
|
||||
)
|
||||
});
|
||||
|
||||
// version & task
|
||||
let (options, saveout) = match args.ver {
|
||||
YOLOVersion::V5 => match args.task {
|
||||
YOLOTask::Classify => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v5-n-cls-dyn.onnx".to_string()))?,
|
||||
"YOLOv5-Classify",
|
||||
),
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v5-n-dyn.onnx".to_string()))?,
|
||||
"YOLOv5-Detect",
|
||||
),
|
||||
YOLOTask::Segment => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v5-n-seg-dyn.onnx".to_string()))?,
|
||||
"YOLOv5-Segment",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::V6 => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options
|
||||
.with_model(&args.model.unwrap_or("yolo/v6-n-dyn.onnx".to_string()))?
|
||||
.with_nc(args.nc),
|
||||
"YOLOv6-Detect",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::V7 => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options
|
||||
.with_model(&args.model.unwrap_or("yolo/v7-tiny-dyn.onnx".to_string()))?
|
||||
.with_nc(args.nc),
|
||||
"YOLOv7-Detect",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::V8 => match args.task {
|
||||
YOLOTask::Classify => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-cls-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Classify",
|
||||
),
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Detect",
|
||||
),
|
||||
YOLOTask::Segment => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-seg-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Segment",
|
||||
),
|
||||
YOLOTask::Pose => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-pose-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Pose",
|
||||
),
|
||||
YOLOTask::Obb => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-obb-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Obb",
|
||||
),
|
||||
},
|
||||
YOLOVersion::V9 => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v9-c-dyn-f16.onnx".to_string()))?,
|
||||
"YOLOv9-Detect",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::V10 => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v10-n.onnx".to_string()))?,
|
||||
"YOLOv10-Detect",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::RTDETR => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/rtdetr-l-f16.onnx".to_string()))?,
|
||||
"RTDETR",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
};
|
||||
|
||||
let options = options
|
||||
.with_yolo_version(args.ver)
|
||||
.with_yolo_task(args.task);
|
||||
// saveout
|
||||
let saveout = format!(
|
||||
"{}-{}-{}",
|
||||
args.ver.name(),
|
||||
args.scale.name(),
|
||||
args.task.name()
|
||||
);
|
||||
|
||||
// device
|
||||
let options = if args.cuda {
|
||||
options.with_cuda(args.device_id)
|
||||
let device = if args.cuda {
|
||||
Device::Cuda(args.device_id)
|
||||
} else if args.trt {
|
||||
let options = options.with_trt(args.device_id);
|
||||
if args.half {
|
||||
options.with_trt_fp16(true)
|
||||
} else {
|
||||
options
|
||||
}
|
||||
Device::Trt(args.device_id)
|
||||
} else if args.coreml {
|
||||
options.with_coreml(args.device_id)
|
||||
Device::CoreML(args.device_id)
|
||||
} else {
|
||||
options.with_cpu()
|
||||
Device::Cpu(args.device_id)
|
||||
};
|
||||
let options = options
|
||||
|
||||
// build options
|
||||
let options = Options::new()
|
||||
.with_model(&path)?
|
||||
.with_yolo_version(args.ver)
|
||||
.with_yolo_task(args.task)
|
||||
.with_device(device)
|
||||
.with_trt_fp16(args.half)
|
||||
.with_ixx(0, 0, (1, args.batch_size as _, 4).into())
|
||||
.with_ixx(0, 2, (args.height_min, args.height, args.height_max).into())
|
||||
.with_ixx(0, 3, (args.width_min, args.width, args.width_max).into())
|
||||
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
|
||||
.with_confs(if args.confs.is_empty() {
|
||||
&[0.2, 0.15]
|
||||
} else {
|
||||
&args.confs
|
||||
})
|
||||
.with_nc(args.nc)
|
||||
// .with_names(&COCO_CLASS_NAMES_80)
|
||||
.with_names2(&COCO_KEYPOINTS_17)
|
||||
// .with_names2(&COCO_KEYPOINTS_17)
|
||||
.with_find_contours(!args.no_contours) // find contours or not
|
||||
.with_profile(args.profile);
|
||||
|
||||
// build model
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// build dataloader
|
||||
@ -194,16 +166,54 @@ fn main() -> Result<()> {
|
||||
// build annotator
|
||||
let annotator = Annotator::default()
|
||||
.with_skeletons(&COCO_SKELETONS_16)
|
||||
.with_bboxes_thickness(4)
|
||||
.without_masks(true) // No masks plotting when doing segment task.
|
||||
.with_saveout(saveout);
|
||||
.with_bboxes_thickness(3)
|
||||
.with_keypoints_name(false) // Enable keypoints names
|
||||
.with_saveout_subs(&["YOLO"])
|
||||
.with_saveout(&saveout);
|
||||
|
||||
// build viewer
|
||||
let mut viewer = if args.view {
|
||||
Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// run & annotate
|
||||
for (xs, _paths) in dl {
|
||||
// let ys = model.run(&xs)?; // way one
|
||||
let ys = model.forward(&xs, args.profile)?; // way two
|
||||
if !args.no_plot {
|
||||
annotator.annotate(&xs, &ys);
|
||||
let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?;
|
||||
|
||||
// show image
|
||||
match &mut viewer {
|
||||
Some(viewer) => viewer.imshow(&images_plotted)?,
|
||||
None => continue,
|
||||
}
|
||||
|
||||
// check out window and key event
|
||||
match &mut viewer {
|
||||
Some(viewer) => {
|
||||
if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
|
||||
// write video
|
||||
if !args.nosave {
|
||||
match &mut viewer {
|
||||
Some(viewer) => viewer.write_batch(&images_plotted)?,
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// finish video write
|
||||
if !args.nosave {
|
||||
if let Some(viewer) = &mut viewer {
|
||||
viewer.finish_write()?;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -20,8 +20,8 @@ pub struct YOLO {
|
||||
confs: DynConf,
|
||||
kconfs: DynConf,
|
||||
iou: f32,
|
||||
names: Option<Vec<String>>,
|
||||
names_kpt: Option<Vec<String>>,
|
||||
names: Vec<String>,
|
||||
names_kpt: Vec<String>,
|
||||
task: YOLOTask,
|
||||
layout: YOLOPreds,
|
||||
find_contours: bool,
|
||||
@ -64,27 +64,26 @@ impl Vision for YOLO {
|
||||
Some(task) => match task {
|
||||
YOLOTask::Classify => match ver {
|
||||
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)),
|
||||
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_clss()),
|
||||
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_clss()),
|
||||
x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||
}
|
||||
YOLOTask::Detect => match ver {
|
||||
YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver),YOLOPreds::n_a_cxcywh_confclss()),
|
||||
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
|
||||
YOLOVersion::V9 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
|
||||
YOLOVersion::V10 => (Some(ver),YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
|
||||
YOLOVersion::RTDETR => (Some(ver),YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
|
||||
YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss()),
|
||||
YOLOVersion::V8 | YOLOVersion::V9 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_a()),
|
||||
YOLOVersion::V10 => (Some(ver), YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
|
||||
YOLOVersion::RTDETR => (Some(ver), YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
|
||||
}
|
||||
YOLOTask::Pose => match ver {
|
||||
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_xycs_a()),
|
||||
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_xycs_a()),
|
||||
x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||
}
|
||||
YOLOTask::Segment => match ver {
|
||||
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()),
|
||||
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
|
||||
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
|
||||
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||
}
|
||||
YOLOTask::Obb => match ver {
|
||||
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
|
||||
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
|
||||
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||
}
|
||||
}
|
||||
@ -97,42 +96,63 @@ impl Vision for YOLO {
|
||||
|
||||
let task = task.unwrap_or(layout.task());
|
||||
|
||||
// The number of classes & Class names
|
||||
let mut names = options.names.or(Self::fetch_names(&engine));
|
||||
let nc = match options.nc {
|
||||
Some(nc) => {
|
||||
match &names {
|
||||
None => names = Some((0..nc).map(|x| x.to_string()).collect::<Vec<String>>()),
|
||||
Some(names) => {
|
||||
assert_eq!(
|
||||
nc,
|
||||
// Class names: user-defined.or(parsed)
|
||||
let names_parsed = Self::fetch_names(&engine);
|
||||
let names = match names_parsed {
|
||||
Some(names_parsed) => match options.names {
|
||||
Some(names) => {
|
||||
if names.len() == names_parsed.len() {
|
||||
Some(names)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"The lengths of parsed class names: {} and user-defined class names: {} do not match.",
|
||||
names_parsed.len(),
|
||||
names.len(),
|
||||
"The length of `nc` and `class names` is not equal."
|
||||
);
|
||||
}
|
||||
}
|
||||
nc
|
||||
}
|
||||
None => match &names {
|
||||
Some(names) => names.len(),
|
||||
None => panic!(
|
||||
"Can not parse model without `nc` and `class names`. Try to make it explicit with `options.with_nc(80)`"
|
||||
None => Some(names_parsed),
|
||||
},
|
||||
None => options.names,
|
||||
};
|
||||
|
||||
// nc: names.len().or(options.nc)
|
||||
let nc = match &names {
|
||||
Some(names) => names.len(),
|
||||
None => match options.nc {
|
||||
Some(nc) => nc,
|
||||
None => anyhow::bail!(
|
||||
"Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names(&[&str])`."
|
||||
),
|
||||
}
|
||||
};
|
||||
|
||||
// Class names
|
||||
let names = match names {
|
||||
None => Self::n2s(nc),
|
||||
Some(names) => names,
|
||||
};
|
||||
|
||||
// Keypoint names & nk
|
||||
let (nk, names_kpt) = match Self::fetch_kpts(&engine) {
|
||||
None => (0, vec![]),
|
||||
Some(nk) => match options.names2 {
|
||||
Some(names) => {
|
||||
if names.len() == nk {
|
||||
(nk, names)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"The lengths of user-defined keypoint names: {} and nk: {} do not match.",
|
||||
names.len(),
|
||||
nk,
|
||||
);
|
||||
}
|
||||
}
|
||||
None => (nk, Self::n2s(nk)),
|
||||
},
|
||||
};
|
||||
|
||||
// Keypoints names
|
||||
let names_kpt = options.names2;
|
||||
|
||||
// The number of keypoints
|
||||
let nk = engine
|
||||
.try_fetch("kpt_shape")
|
||||
.map(|kpt_string| {
|
||||
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
|
||||
let caps = re.captures(&kpt_string).unwrap();
|
||||
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
|
||||
})
|
||||
.unwrap_or(0_usize);
|
||||
// Confs & Iou
|
||||
let confs = DynConf::new(&options.confs, nc);
|
||||
let kconfs = DynConf::new(&options.kconfs, nk);
|
||||
let iou = options.iou.unwrap_or(0.45);
|
||||
@ -140,6 +160,7 @@ impl Vision for YOLO {
|
||||
// Summary
|
||||
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);
|
||||
|
||||
// dry run
|
||||
engine.dry_run()?;
|
||||
|
||||
Ok(Self {
|
||||
@ -219,10 +240,8 @@ impl Vision for YOLO {
|
||||
slice_clss.into_owned()
|
||||
};
|
||||
let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0);
|
||||
if let Some(names) = &self.names {
|
||||
probs =
|
||||
probs.with_names(&names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
|
||||
}
|
||||
probs = probs
|
||||
.with_names(&self.names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
|
||||
|
||||
return Some(y.with_probs(&probs));
|
||||
}
|
||||
@ -325,9 +344,7 @@ impl Vision for YOLO {
|
||||
)
|
||||
.with_confidence(confidence)
|
||||
.with_id(class_id as isize);
|
||||
if let Some(names) = &self.names {
|
||||
mbr = mbr.with_name(&names[class_id]);
|
||||
}
|
||||
mbr = mbr.with_name(&self.names[class_id]);
|
||||
|
||||
(None, Some(mbr))
|
||||
}
|
||||
@ -337,9 +354,7 @@ impl Vision for YOLO {
|
||||
.with_confidence(confidence)
|
||||
.with_id(class_id as isize)
|
||||
.with_id_born(i as isize);
|
||||
if let Some(names) = &self.names {
|
||||
bbox = bbox.with_name(&names[class_id]);
|
||||
}
|
||||
bbox = bbox.with_name(&self.names[class_id]);
|
||||
|
||||
(Some(bbox), None)
|
||||
}
|
||||
@ -394,9 +409,7 @@ impl Vision for YOLO {
|
||||
ky.max(0.0f32).min(image_height),
|
||||
);
|
||||
|
||||
if let Some(names) = &self.names_kpt {
|
||||
kpt = kpt.with_name(&names[i]);
|
||||
}
|
||||
kpt = kpt.with_name(&self.names_kpt[i]);
|
||||
kpt
|
||||
}
|
||||
})
|
||||
@ -505,16 +518,16 @@ impl Vision for YOLO {
|
||||
}
|
||||
|
||||
impl YOLO {
|
||||
pub fn batch(&self) -> isize {
|
||||
self.batch.opt() as _
|
||||
pub fn batch(&self) -> usize {
|
||||
self.batch.opt()
|
||||
}
|
||||
|
||||
pub fn width(&self) -> isize {
|
||||
self.width.opt() as _
|
||||
pub fn width(&self) -> usize {
|
||||
self.width.opt()
|
||||
}
|
||||
|
||||
pub fn height(&self) -> isize {
|
||||
self.height.opt() as _
|
||||
pub fn height(&self) -> usize {
|
||||
self.height.opt()
|
||||
}
|
||||
|
||||
pub fn version(&self) -> Option<&YOLOVersion> {
|
||||
@ -541,4 +554,16 @@ impl YOLO {
|
||||
names_
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch_kpts(engine: &OrtEngine) -> Option<usize> {
|
||||
engine.try_fetch("kpt_shape").map(|s| {
|
||||
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
|
||||
let caps = re.captures(&s).unwrap();
|
||||
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
fn n2s(n: usize) -> Vec<String> {
|
||||
(0..n).map(|x| format!("# {}", x)).collect::<Vec<String>>()
|
||||
}
|
||||
}
|
||||
|
@ -9,6 +9,28 @@ pub enum YOLOTask {
|
||||
Obb,
|
||||
}
|
||||
|
||||
impl YOLOTask {
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
Self::Classify => "cls".to_string(),
|
||||
Self::Detect => "det".to_string(),
|
||||
Self::Pose => "pose".to_string(),
|
||||
Self::Segment => "seg".to_string(),
|
||||
Self::Obb => "obb".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name_detailed(&self) -> String {
|
||||
match self {
|
||||
Self::Classify => "image-classification".to_string(),
|
||||
Self::Detect => "object-detection".to_string(),
|
||||
Self::Pose => "pose-estimation".to_string(),
|
||||
Self::Segment => "instance-segment".to_string(),
|
||||
Self::Obb => "oriented-object-detection".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
|
||||
pub enum YOLOVersion {
|
||||
V5,
|
||||
@ -17,9 +39,54 @@ pub enum YOLOVersion {
|
||||
V8,
|
||||
V9,
|
||||
V10,
|
||||
V11,
|
||||
RTDETR,
|
||||
}
|
||||
|
||||
impl YOLOVersion {
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
Self::V5 => "v5".to_string(),
|
||||
Self::V6 => "v6".to_string(),
|
||||
Self::V7 => "v7".to_string(),
|
||||
Self::V8 => "v8".to_string(),
|
||||
Self::V9 => "v9".to_string(),
|
||||
Self::V10 => "v10".to_string(),
|
||||
Self::V11 => "v11".to_string(),
|
||||
Self::RTDETR => "rtdetr".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
|
||||
pub enum YOLOScale {
|
||||
N,
|
||||
T,
|
||||
B,
|
||||
S,
|
||||
M,
|
||||
L,
|
||||
C,
|
||||
E,
|
||||
X,
|
||||
}
|
||||
|
||||
impl YOLOScale {
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
Self::N => "n".to_string(),
|
||||
Self::T => "t".to_string(),
|
||||
Self::S => "s".to_string(),
|
||||
Self::B => "b".to_string(),
|
||||
Self::M => "m".to_string(),
|
||||
Self::L => "l".to_string(),
|
||||
Self::C => "c".to_string(),
|
||||
Self::E => "e".to_string(),
|
||||
Self::X => "x".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum BoxType {
|
||||
/// 1
|
||||
|
Reference in New Issue
Block a user