mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
Add support for restricting detection classes (#45)
* Add support for restricting detection classes in `Options`
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "usls"
|
||||
version = "0.0.17"
|
||||
version = "0.0.18"
|
||||
edition = "2021"
|
||||
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
|
||||
repository = "https://github.com/jamjamjon/usls"
|
||||
|
@ -160,6 +160,8 @@ fn main() -> Result<()> {
|
||||
// .with_names(&COCO_CLASS_NAMES_80)
|
||||
// .with_names2(&COCO_KEYPOINTS_17)
|
||||
.with_find_contours(!args.no_contours) // find contours or not
|
||||
.exclude_classes(&[0])
|
||||
// .retain_classes(&[0, 5])
|
||||
.with_profile(args.profile);
|
||||
|
||||
// build model
|
||||
|
@ -48,6 +48,8 @@ pub struct Options {
|
||||
pub sam_kind: Option<SamKind>,
|
||||
pub use_low_res_mask: Option<bool>,
|
||||
pub sapiens_task: Option<SapiensTask>,
|
||||
pub classes_excluded: Vec<isize>,
|
||||
pub classes_retained: Vec<isize>,
|
||||
}
|
||||
|
||||
impl Default for Options {
|
||||
@ -88,6 +90,8 @@ impl Default for Options {
|
||||
use_low_res_mask: None,
|
||||
sapiens_task: None,
|
||||
task: Task::Untitled,
|
||||
classes_excluded: vec![],
|
||||
classes_retained: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -276,4 +280,16 @@ impl Options {
|
||||
self.iiixs.push(Iiix::from((i, ii, x)));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn exclude_classes(mut self, xs: &[isize]) -> Self {
|
||||
self.classes_retained.clear();
|
||||
self.classes_excluded.extend_from_slice(xs);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn retain_classes(mut self, xs: &[isize]) -> Self {
|
||||
self.classes_excluded.clear();
|
||||
self.classes_retained.extend_from_slice(xs);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
@ -26,6 +26,8 @@ pub struct YOLO {
|
||||
layout: YOLOPreds,
|
||||
find_contours: bool,
|
||||
version: Option<YOLOVersion>,
|
||||
classes_excluded: Vec<isize>,
|
||||
classes_retained: Vec<isize>,
|
||||
}
|
||||
|
||||
impl Vision for YOLO {
|
||||
@ -157,6 +159,10 @@ impl Vision for YOLO {
|
||||
let kconfs = DynConf::new(&options.kconfs, nk);
|
||||
let iou = options.iou.unwrap_or(0.45);
|
||||
|
||||
// Classes excluded and retained
|
||||
let classes_excluded = options.classes_excluded;
|
||||
let classes_retained = options.classes_retained;
|
||||
|
||||
// Summary
|
||||
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);
|
||||
|
||||
@ -179,6 +185,8 @@ impl Vision for YOLO {
|
||||
layout,
|
||||
version,
|
||||
find_contours: options.find_contours,
|
||||
classes_excluded,
|
||||
classes_retained,
|
||||
})
|
||||
}
|
||||
|
||||
@ -276,7 +284,19 @@ impl Vision for YOLO {
|
||||
}
|
||||
};
|
||||
|
||||
// filtering
|
||||
// filtering by class id
|
||||
if !self.classes_excluded.is_empty()
|
||||
&& self.classes_excluded.contains(&(class_id as isize))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
if !self.classes_retained.is_empty()
|
||||
&& !self.classes_retained.contains(&(class_id as isize))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// filtering by conf
|
||||
if confidence < self.confs[class_id] {
|
||||
return None;
|
||||
}
|
||||
|
Reference in New Issue
Block a user