Add support for restricting detection classes (#45)

* Add support for restricting detection classes in `Options`
This commit is contained in:
Jamjamjon
2024-10-05 17:49:08 +08:00
committed by GitHub
parent 0102c15687
commit 1d596383de
4 changed files with 40 additions and 2 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "usls"
version = "0.0.17"
version = "0.0.18"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"

View File

@ -160,6 +160,8 @@ fn main() -> Result<()> {
// .with_names(&COCO_CLASS_NAMES_80)
// .with_names2(&COCO_KEYPOINTS_17)
.with_find_contours(!args.no_contours) // find contours or not
.exclude_classes(&[0])
// .retain_classes(&[0, 5])
.with_profile(args.profile);
// build model

View File

@ -48,6 +48,8 @@ pub struct Options {
pub sam_kind: Option<SamKind>,
pub use_low_res_mask: Option<bool>,
pub sapiens_task: Option<SapiensTask>,
pub classes_excluded: Vec<isize>,
pub classes_retained: Vec<isize>,
}
impl Default for Options {
@ -88,6 +90,8 @@ impl Default for Options {
use_low_res_mask: None,
sapiens_task: None,
task: Task::Untitled,
classes_excluded: vec![],
classes_retained: vec![],
}
}
}
@ -276,4 +280,16 @@ impl Options {
self.iiixs.push(Iiix::from((i, ii, x)));
self
}
pub fn exclude_classes(mut self, xs: &[isize]) -> Self {
self.classes_retained.clear();
self.classes_excluded.extend_from_slice(xs);
self
}
pub fn retain_classes(mut self, xs: &[isize]) -> Self {
self.classes_excluded.clear();
self.classes_retained.extend_from_slice(xs);
self
}
}

View File

@ -26,6 +26,8 @@ pub struct YOLO {
layout: YOLOPreds,
find_contours: bool,
version: Option<YOLOVersion>,
classes_excluded: Vec<isize>,
classes_retained: Vec<isize>,
}
impl Vision for YOLO {
@ -157,6 +159,10 @@ impl Vision for YOLO {
let kconfs = DynConf::new(&options.kconfs, nk);
let iou = options.iou.unwrap_or(0.45);
// Classes excluded and retained
let classes_excluded = options.classes_excluded;
let classes_retained = options.classes_retained;
// Summary
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);
@ -179,6 +185,8 @@ impl Vision for YOLO {
layout,
version,
find_contours: options.find_contours,
classes_excluded,
classes_retained,
})
}
@ -276,7 +284,19 @@ impl Vision for YOLO {
}
};
// filtering
// filtering by class id
if !self.classes_excluded.is_empty()
&& self.classes_excluded.contains(&(class_id as isize))
{
return None;
}
if !self.classes_retained.is_empty()
&& !self.classes_retained.contains(&(class_id as isize))
{
return None;
}
// filtering by conf
if confidence < self.confs[class_id] {
return None;
}