mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 23:55:38 +00:00
Add query method for dinov2 and adjust DataLoader
This commit is contained in:
@ -1,12 +1,13 @@
|
||||
[package]
|
||||
name = "usls"
|
||||
version = "0.0.1"
|
||||
version = "0.0.2"
|
||||
edition = "2021"
|
||||
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
|
||||
repository = "https://github.com/jamjamjon/usls"
|
||||
authors = ["Jamjamjon <jamjamjon.usls@gmail.com>"]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
exclude = ["assets/*", "examples/*"]
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
|
64
README.md
64
README.md
@ -2,7 +2,6 @@
|
||||
|
||||
A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vison** and **Vision-Language** models including [YOLOv8](https://github.com/ultralytics/ultralytics) `(Classification, Segmentation, Detection and Pose Detection)`, [YOLOv9](https://github.com/WongKinYiu/yolov9), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), and others. Many execution providers are supported, sunch as `CUDA`, `TensorRT` and `CoreML`.
|
||||
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Example | CUDA(f32) | CUDA(f16) | TensorRT(f32) | TensorRT(f16) |
|
||||
@ -11,7 +10,7 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp
|
||||
| YOLOv8-pose | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
|
||||
| YOLOv8-classification | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
|
||||
| YOLOv8-segmentation | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
|
||||
| YOLOv8-OBB | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | |
|
||||
| YOLOv8-OBB | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** |
|
||||
| YOLOv9 | [demo](examples/yolov9) | ✅ | ✅ | ✅ | ✅ |
|
||||
| RT-DETR | [demo](examples/rtdetr) | ✅ | ✅ | ✅ | ✅ |
|
||||
| FastSAM | [demo](examples/fastsam) | ✅ | ✅ | ✅ | ✅ |
|
||||
@ -19,17 +18,18 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp
|
||||
| DINOv2 | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ |
|
||||
| CLIP | [demo](examples/clip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
|
||||
| BLIP | [demo](examples/blip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
|
||||
| OCR(DB, SVTR) | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | |
|
||||
| OCR(DB, SVTR) | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** |
|
||||
|
||||
## Solution Models
|
||||
|
||||
Additionally, this repo also provides some solution models such as pedestrian `fall detection`, `head detection`, `trash detection`, and more.
|
||||
|
||||
| Model | Example | Result |
|
||||
| :---------------------------: | :------------------------------: | :--------------------------------------------------------------------------: |
|
||||
| face-landmark detection | [demo](examples/yolov8-face) | <img src="./examples/yolov8-face/demo.jpg" width="400" height="300"> |
|
||||
| head detection | [demo](examples/yolov8-head) | <img src="./examples/yolov8-head/demo.jpg" width="400" height="300"> |
|
||||
| fall detection | [demo](examples/yolov8-falldown) | <img src="./examples/yolov8-falldown/demo.jpg" width="400" height="300"> |
|
||||
| trash detection | [demo](examples/yolov8-plastic-bag) | <img src="./examples/yolov8-trash/demo.jpg" width="400" height="260"> |
|
||||
| Model | Example |
|
||||
| :---------------------: | :------------------------------: |
|
||||
| face-landmark detection | [demo](examples/yolov8-face) |
|
||||
| head detection | [demo](examples/yolov8-head) |
|
||||
| fall detection | [demo](examples/yolov8-falldown) |
|
||||
| trash detection | [demo](examples/yolov8-plastic-bag) |
|
||||
|
||||
## Demo
|
||||
|
||||
@ -51,46 +51,56 @@ check **[ort guide](https://ort.pyke.io/setup/linking)**
|
||||
```shell
|
||||
export ORT_DYLIB_PATH=/Users/qweasd/Desktop/onnxruntime-osx-arm64-1.17.1/lib/libonnxruntime.1.17.1.dylib
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### 2. Add `usls` as a dependency to your project's `Cargo.toml:`
|
||||
#### 2. Add `usls` as a dependency to your project's `Cargo.toml`
|
||||
|
||||
```
|
||||
[dependencies]
|
||||
usls = "0.0.1"
|
||||
```shell
|
||||
cargo add --git https://github.com/jamjamjon/usls
|
||||
|
||||
# or
|
||||
cargo add usls
|
||||
```
|
||||
|
||||
#### 3. Set model `Options` and build `model`, then you're ready to go.
|
||||
|
||||
#### 3. Set `Options` and build model
|
||||
```Rust
|
||||
2use usls::{models::YOLO, Options};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// 1.build model
|
||||
let options = Options::default()
|
||||
let options = Options::default()
|
||||
.with_model("../models/yolov8m-seg-dyn-f16.onnx")
|
||||
.with_trt(0) // using cuda(0) by default
|
||||
// when model with dynamic shapes
|
||||
// when model with dynamic shapes
|
||||
.with_i00((1, 2, 4).into()) // dynamic batch
|
||||
.with_i02((416, 640, 800).into()) // dynamic height
|
||||
.with_i03((416, 640, 800).into()) // dynamic width
|
||||
.with_confs(&[0.4, 0.15]) // person: 0.4, others: 0.15
|
||||
.with_dry_run(3)
|
||||
.with_saveout("YOLOv8"); // save results
|
||||
let mut model = YOLO::new(&options)?;
|
||||
let mut model = YOLO::new(&options)?;
|
||||
```
|
||||
|
||||
// 2.build dataloader
|
||||
let dl = DataLoader::default()
|
||||
#### 4. Prepare inputs, and then you're ready to go
|
||||
|
||||
- Build `DataLoader` to load images
|
||||
|
||||
```Rust
|
||||
let dl = DataLoader::default()
|
||||
.with_batch(model.batch.opt as usize)
|
||||
.load("./assets/")?;
|
||||
|
||||
// 3.run
|
||||
for (xs, _paths) in dl {
|
||||
for (xs, _paths) in dl {
|
||||
let _y = model.run(&xs)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
- Or simply read one image
|
||||
|
||||
```Rust
|
||||
let x = DataLoader::try_read("./assets/bus.jpg")?;
|
||||
let _y = model.run(&[x])?;
|
||||
```
|
||||
|
||||
|
||||
## Script: converte ONNX model from `float32` to `float16`
|
||||
|
||||
```python
|
||||
|
@ -16,7 +16,7 @@ cargo run -r --example dinov2
|
||||
|
||||
[dinov2-b14](https://github.com/jamjamjon/assets/releases/download/v0.0.1/dinov2-b14.onnx)
|
||||
[dinov2-b14-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/dinov2-b14-dyn.onnx)
|
||||
[dinov2-b14-dyn-f16](https://github.com/jamjamjon/assets/releases/download/v0.0.1/dinov2-b14-dyn-f16.onnx)
|
||||
|
||||
|
||||
### 2. Specify the ONNX model path in `main.rs`
|
||||
|
||||
@ -24,14 +24,6 @@ cargo run -r --example dinov2
|
||||
let options = Options::default()
|
||||
.with_model("ONNX_PATH") // <= modify this
|
||||
.with_profile(false);
|
||||
|
||||
// build index
|
||||
let options = IndexOptions {
|
||||
dimensions: 384, // 768 for vitb; 384 for vits
|
||||
metric: MetricKind::L2sq,
|
||||
quantization: ScalarKind::F16,
|
||||
..Default::default()
|
||||
};
|
||||
```
|
||||
|
||||
### 3. Then, run
|
||||
@ -43,8 +35,7 @@ cargo run -r --example dinov2
|
||||
## Results
|
||||
|
||||
```shell
|
||||
Top-1 distance: 0.0 => "./examples/dinov2/images/bus.jpg"
|
||||
Top-2 distance: 1.8332717 => "./examples/dinov2/images/dog.png"
|
||||
Top-3 distance: 1.9672602 => "./examples/dinov2/images/cat.png"
|
||||
Top-4 distance: 1.978817 => "./examples/dinov2/images/carrot.jpg"
|
||||
Top-1 0.0000000 /home/qweasd/Desktop/usls/examples/dinov2/images/bus.jpg
|
||||
Top-2 1.9059424 /home/qweasd/Desktop/usls/examples/dinov2/images/1.jpg
|
||||
Top-3 1.9736203 /home/qweasd/Desktop/usls/examples/dinov2/images/2.jpg
|
||||
```
|
||||
|
Before Width: | Height: | Size: 66 KiB After Width: | Height: | Size: 66 KiB |
Before Width: | Height: | Size: 44 KiB After Width: | Height: | Size: 44 KiB |
@ -1,55 +1,36 @@
|
||||
use usearch::ffi::{IndexOptions, MetricKind, ScalarKind};
|
||||
use usls::{models::Dinov2, DataLoader, Options};
|
||||
use usls::{models::Dinov2, Metric, Options};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// build model
|
||||
let options = Options::default()
|
||||
.with_model("../models/dinov2-s14-dyn-f16.onnx")
|
||||
// .with_model("../models/dinov2-b14-dyn.onnx")
|
||||
.with_i00((1, 1, 1).into())
|
||||
.with_i02((224, 224, 224).into())
|
||||
.with_i03((224, 224, 224).into());
|
||||
let mut model = Dinov2::new(&options)?;
|
||||
|
||||
// build dataloader
|
||||
let dl = DataLoader::default()
|
||||
.with_batch(model.batch.opt as usize)
|
||||
.load("./examples/dinov2/images")?;
|
||||
// query from vector
|
||||
let ys = model.query_from_vec(
|
||||
"./assets/bus.jpg",
|
||||
&[
|
||||
"./examples/dinov2/images/bus.jpg",
|
||||
"./examples/dinov2/images/1.jpg",
|
||||
"./examples/dinov2/images/2.jpg",
|
||||
],
|
||||
Metric::L2,
|
||||
)?;
|
||||
|
||||
// load query
|
||||
let query = image::io::Reader::open("./assets/bus.jpg")?.decode()?;
|
||||
let query = model.run(&[query])?;
|
||||
// or query from folder
|
||||
// let ys = model.query_from_folder("./assets/bus.jpg", "./examples/dinov2/images", Metric::IP)?;
|
||||
|
||||
// build index
|
||||
let options = IndexOptions {
|
||||
dimensions: 384, // 768 for vitb; 384 for vits
|
||||
metric: MetricKind::L2sq,
|
||||
quantization: ScalarKind::F16,
|
||||
..Default::default()
|
||||
};
|
||||
let index = usearch::new_index(&options)?;
|
||||
index.reserve(dl.clone().count())?;
|
||||
|
||||
// load feats
|
||||
for (idx, (image, _path)) in dl.clone().enumerate() {
|
||||
let y = model.run(&image)?;
|
||||
index.add(idx as u64, &y.into_raw_vec())?;
|
||||
}
|
||||
|
||||
// output
|
||||
let topk = 10;
|
||||
let matches = index.search(&query.into_raw_vec(), topk)?;
|
||||
let paths = dl.paths;
|
||||
for (idx, (k, score)) in matches
|
||||
.keys
|
||||
.into_iter()
|
||||
.zip(matches.distances.into_iter())
|
||||
.enumerate()
|
||||
{
|
||||
// results
|
||||
for (i, y) in ys.iter().enumerate() {
|
||||
println!(
|
||||
"Top-{} distance: {:?} => {:?}",
|
||||
idx + 1,
|
||||
score,
|
||||
paths[k as usize]
|
||||
"Top-{:<3}{:.7} {}",
|
||||
i + 1,
|
||||
y.1,
|
||||
y.2.canonicalize()?.display()
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -12,11 +12,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.with_profile(false);
|
||||
let mut model = YOLO::new(&options)?;
|
||||
|
||||
// build dataloader
|
||||
let mut dl = DataLoader::default().load("./assets/kids.jpg")?;
|
||||
// load image
|
||||
let x = DataLoader::try_read("./assets/kids.jpg")?;
|
||||
|
||||
// run
|
||||
model.run(&dl.next().unwrap().0)?;
|
||||
let _y = model.run(&[x])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -4,22 +4,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// 1.build model
|
||||
let options = Options::default()
|
||||
.with_model("../models/yolov8m-dyn-f16.onnx")
|
||||
.with_trt(0) // cuda by default
|
||||
.with_fp16(true)
|
||||
// .with_trt(0) // cuda by default
|
||||
// .with_fp16(true)
|
||||
.with_i00((1, 1, 4).into())
|
||||
.with_i02((416, 640, 800).into())
|
||||
.with_i03((416, 640, 800).into())
|
||||
.with_confs(&[0.4, 0.15]) // person: 0.4, others: 0.15
|
||||
.with_profile(true)
|
||||
.with_dry_run(5)
|
||||
.with_profile(false)
|
||||
.with_dry_run(3)
|
||||
.with_skeletons(&COCO_SKELETON_17)
|
||||
.with_saveout("YOLOv8");
|
||||
let mut model = YOLO::new(&options)?;
|
||||
|
||||
// 2.build dataloader
|
||||
let dl = DataLoader::default()
|
||||
.with_batch(1)
|
||||
.load("./assets/bus.jpg")?;
|
||||
let dl = DataLoader::default().with_batch(1).load("./assets")?;
|
||||
|
||||
// 3.run
|
||||
for (xs, _paths) in dl {
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::{CHECK_MARK, CROSS_MARK, SAFE_CROSS_MARK};
|
||||
use anyhow::Result;
|
||||
use crate::{CHECK_MARK, SAFE_CROSS_MARK};
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use image::DynamicImage;
|
||||
use std::collections::VecDeque;
|
||||
use std::path::{Path, PathBuf};
|
||||
@ -7,10 +7,9 @@ use walkdir::{DirEntry, WalkDir};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DataLoader {
|
||||
// source could be single image, folder with images (TODO: video, stream)
|
||||
pub source: PathBuf,
|
||||
pub batch: usize,
|
||||
// source could be single image path, folder with images (TODO: video, stream)
|
||||
pub recursive: bool,
|
||||
pub batch: usize,
|
||||
pub paths: VecDeque<PathBuf>,
|
||||
}
|
||||
|
||||
@ -25,25 +24,14 @@ impl Iterator for DataLoader {
|
||||
let mut yps: Vec<PathBuf> = Vec::new();
|
||||
loop {
|
||||
let path = self.paths.pop_front().unwrap();
|
||||
match image::io::Reader::open(&path) {
|
||||
match Self::try_read(&path) {
|
||||
Err(err) => {
|
||||
println!(
|
||||
"{SAFE_CROSS_MARK} Faild to load image: {:?} -> {:?}",
|
||||
self.paths[0], err
|
||||
);
|
||||
}
|
||||
Ok(p) => match p.decode() {
|
||||
Err(err) => {
|
||||
println!(
|
||||
"{SAFE_CROSS_MARK} Fail to load image: {:?} -> {:?}",
|
||||
self.paths[0], err
|
||||
);
|
||||
println!("{SAFE_CROSS_MARK} {err}");
|
||||
}
|
||||
Ok(x) => {
|
||||
yis.push(x);
|
||||
yps.push(path);
|
||||
}
|
||||
},
|
||||
}
|
||||
if self.paths.is_empty() || yis.len() == self.batch {
|
||||
break;
|
||||
@ -59,14 +47,13 @@ impl Default for DataLoader {
|
||||
Self {
|
||||
batch: 1,
|
||||
recursive: false,
|
||||
source: Default::default(),
|
||||
paths: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DataLoader {
|
||||
pub fn load<P: AsRef<Path>>(&self, source: P) -> Result<Self> {
|
||||
pub fn load<P: AsRef<Path>>(&mut self, source: P) -> Result<Self> {
|
||||
let source = source.as_ref();
|
||||
let mut paths = VecDeque::new();
|
||||
|
||||
@ -88,18 +75,29 @@ impl DataLoader {
|
||||
}
|
||||
}
|
||||
// s if s.starts_with("rtsp://") || s.starts_with("rtmp://") || s.starts_with("http://")|| s.starts_with("https://") => todo!(),
|
||||
s if !s.exists() => panic!("{CROSS_MARK} File not found: {s:?}"),
|
||||
s if !s.exists() => bail!("{s:?} Not Exists"),
|
||||
_ => todo!(),
|
||||
}
|
||||
println!("{CHECK_MARK} {} files found\n", &paths.len());
|
||||
let n_new = paths.len();
|
||||
self.paths.append(&mut paths);
|
||||
println!(
|
||||
"{CHECK_MARK} {n_new} files found ({} total)",
|
||||
self.paths.len()
|
||||
);
|
||||
Ok(Self {
|
||||
paths,
|
||||
source: source.into(),
|
||||
paths: self.paths.to_owned(),
|
||||
batch: self.batch,
|
||||
recursive: self.recursive,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn try_read<P: AsRef<Path>>(path: P) -> Result<DynamicImage> {
|
||||
image::io::Reader::open(&path)
|
||||
.map_err(|_| anyhow!("Failed to open image at {:?}", path.as_ref()))?
|
||||
.decode()
|
||||
.map_err(|_| anyhow!("Failed to decode image at {:?}", path.as_ref()))
|
||||
}
|
||||
|
||||
pub fn with_batch(mut self, x: usize) -> Self {
|
||||
self.batch = x;
|
||||
self
|
||||
@ -110,6 +108,10 @@ impl DataLoader {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn paths(&self) -> &VecDeque<PathBuf> {
|
||||
&self.paths
|
||||
}
|
||||
|
||||
fn _is_hidden(entry: &DirEntry) -> bool {
|
||||
entry
|
||||
.file_name()
|
||||
|
@ -6,6 +6,7 @@ mod dynconf;
|
||||
mod embedding;
|
||||
mod engine;
|
||||
mod keypoint;
|
||||
mod metric;
|
||||
mod min_opt_max;
|
||||
pub mod models;
|
||||
pub mod ops;
|
||||
@ -25,6 +26,7 @@ pub use dynconf::DynConf;
|
||||
pub use embedding::Embedding;
|
||||
pub use engine::OrtEngine;
|
||||
pub use keypoint::Keypoint;
|
||||
pub use metric::Metric;
|
||||
pub use min_opt_max::MinOptMax;
|
||||
pub use options::Options;
|
||||
pub use point::Point;
|
||||
|
6
src/metric.rs
Normal file
6
src/metric.rs
Normal file
@ -0,0 +1,6 @@
|
||||
#[derive(Debug)]
|
||||
pub enum Metric {
|
||||
IP,
|
||||
Cos,
|
||||
L2,
|
||||
}
|
@ -1,7 +1,15 @@
|
||||
use crate::{ops, MinOptMax, Options, OrtEngine};
|
||||
use crate::{ops, DataLoader, Metric, MinOptMax, Options, OrtEngine};
|
||||
use anyhow::Result;
|
||||
use image::DynamicImage;
|
||||
use ndarray::{Array, IxDyn};
|
||||
use std::path::PathBuf;
|
||||
use usearch::ffi::{IndexOptions, MetricKind, ScalarKind};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Model {
|
||||
S,
|
||||
B,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Dinov2 {
|
||||
@ -9,6 +17,7 @@ pub struct Dinov2 {
|
||||
pub height: MinOptMax,
|
||||
pub width: MinOptMax,
|
||||
pub batch: MinOptMax,
|
||||
pub hidden_size: usize,
|
||||
}
|
||||
|
||||
impl Dinov2 {
|
||||
@ -19,6 +28,15 @@ impl Dinov2 {
|
||||
engine.inputs_minoptmax()[0][2].to_owned(),
|
||||
engine.inputs_minoptmax()[0][3].to_owned(),
|
||||
);
|
||||
let which = match &options.onnx_path {
|
||||
s if s.contains("b14") => Model::B,
|
||||
s if s.contains("s14") => Model::S,
|
||||
_ => todo!(),
|
||||
};
|
||||
let hidden_size = match which {
|
||||
Model::S => 384,
|
||||
Model::B => 768,
|
||||
};
|
||||
engine.dry_run()?;
|
||||
|
||||
Ok(Self {
|
||||
@ -26,6 +44,7 @@ impl Dinov2 {
|
||||
height,
|
||||
width,
|
||||
batch,
|
||||
hidden_size,
|
||||
})
|
||||
}
|
||||
|
||||
@ -36,4 +55,96 @@ impl Dinov2 {
|
||||
let ys = ops::norm(&ys);
|
||||
Ok(ys)
|
||||
}
|
||||
|
||||
pub fn build_index(&self, metric: Metric) -> Result<usearch::Index> {
|
||||
let metric = match metric {
|
||||
Metric::IP => MetricKind::IP,
|
||||
Metric::L2 => MetricKind::L2sq,
|
||||
Metric::Cos => MetricKind::Cos,
|
||||
};
|
||||
let options = IndexOptions {
|
||||
metric,
|
||||
dimensions: self.hidden_size,
|
||||
quantization: ScalarKind::F16,
|
||||
..Default::default()
|
||||
};
|
||||
Ok(usearch::new_index(&options)?)
|
||||
}
|
||||
|
||||
pub fn query_from_folder(
|
||||
&mut self,
|
||||
qurey: &str,
|
||||
gallery: &str,
|
||||
metric: Metric,
|
||||
) -> Result<Vec<(usize, f32, PathBuf)>> {
|
||||
// load query
|
||||
let query = DataLoader::try_read(qurey)?;
|
||||
let query = self.run(&[query])?;
|
||||
|
||||
// build index & gallery
|
||||
let index = self.build_index(metric)?;
|
||||
let dl = DataLoader::default()
|
||||
.with_batch(self.batch.opt as usize)
|
||||
.load(gallery)?;
|
||||
let paths = dl.paths().to_owned();
|
||||
index.reserve(paths.len())?;
|
||||
|
||||
// load feats
|
||||
for (idx, (x, _path)) in dl.enumerate() {
|
||||
let y = self.run(&x)?;
|
||||
index.add(idx as u64, &y.into_raw_vec())?;
|
||||
}
|
||||
|
||||
// output
|
||||
let matches = index.search(&query.into_raw_vec(), index.size())?;
|
||||
let mut results: Vec<(usize, f32, PathBuf)> = Vec::new();
|
||||
matches
|
||||
.keys
|
||||
.into_iter()
|
||||
.zip(matches.distances)
|
||||
.for_each(|(k, score)| {
|
||||
results.push((k as usize, score, paths[k as usize].to_owned()));
|
||||
});
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
pub fn query_from_vec(
|
||||
&mut self,
|
||||
qurey: &str,
|
||||
gallery: &[&str],
|
||||
metric: Metric,
|
||||
) -> Result<Vec<(usize, f32, PathBuf)>> {
|
||||
// load query
|
||||
let query = DataLoader::try_read(qurey)?;
|
||||
let query = self.run(&[query])?;
|
||||
|
||||
// build index & gallery
|
||||
let index = self.build_index(metric)?;
|
||||
index.reserve(gallery.len())?;
|
||||
let mut dl = DataLoader::default().with_batch(self.batch.opt as usize);
|
||||
gallery.iter().for_each(|x| {
|
||||
dl.load(x).unwrap();
|
||||
});
|
||||
|
||||
// load feats
|
||||
let paths = dl.paths().to_owned();
|
||||
for (idx, (x, _path)) in dl.enumerate() {
|
||||
let y = self.run(&x)?;
|
||||
index.add(idx as u64, &y.into_raw_vec())?;
|
||||
}
|
||||
|
||||
// output
|
||||
let matches = index.search(&query.into_raw_vec(), index.size())?;
|
||||
let mut results: Vec<(usize, f32, PathBuf)> = Vec::new();
|
||||
matches
|
||||
.keys
|
||||
.into_iter()
|
||||
.zip(matches.distances)
|
||||
.for_each(|(k, score)| {
|
||||
results.push((k as usize, score, paths[k as usize].to_owned()));
|
||||
});
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user