From 475a680703b0e4d272ff6f7cc9ba2ae4c8467bcc Mon Sep 17 00:00:00 2001
From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com>
Date: Mon, 20 Jan 2025 21:37:54 +0800
Subject: [PATCH] Add moondream2
* Add moondream2
* Update README.md
---
Cargo.toml | 3 +-
README.md | 3 +-
examples/florence2/main.rs | 8 +-
examples/moondream2/README.md | 10 +
examples/moondream2/main.rs | 157 ++++++++
src/misc/device.rs | 4 +-
src/misc/dtype.rs | 3 +
src/misc/engine.rs | 2 +
src/misc/options.rs | 15 +-
src/misc/scale.rs | 18 +
src/misc/task.rs | 34 +-
src/models/florence2/impl.rs | 2 +-
src/models/mod.rs | 2 +
src/models/moondream2/README.md | 9 +
src/models/moondream2/config.rs | 117 ++++++
src/models/moondream2/impl.rs | 645 ++++++++++++++++++++++++++++++++
src/models/moondream2/mod.rs | 4 +
src/models/yolo/impl.rs | 6 +-
18 files changed, 1019 insertions(+), 23 deletions(-)
create mode 100644 examples/moondream2/README.md
create mode 100644 examples/moondream2/main.rs
create mode 100644 src/models/moondream2/README.md
create mode 100644 src/models/moondream2/config.rs
create mode 100644 src/models/moondream2/impl.rs
create mode 100644 src/models/moondream2/mod.rs
diff --git a/Cargo.toml b/Cargo.toml
index 7d3f0e6..efed00d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,7 +14,7 @@ exclude = ["assets/*", "examples/*", "runs/*", "benches/*"]
aksr = { version = "0.0.2" }
image = { version = "0.25.2" }
imageproc = { version = "0.24" }
-ndarray = { version = "0.16.1", features = ["rayon"] }
+ndarray = { version = "0.16.1", features = ["rayon", "serde"] }
rayon = { version = "1.10.0" }
anyhow = { version = "1.0.75" }
regex = { version = "1.5.4" }
@@ -38,6 +38,7 @@ natord = "1.0.9"
video-rs = { version = "0.10.0", features = ["ndarray"], optional = true }
minifb = { version = "0.27.0", optional = true }
sha2 = "0.10.8"
+ndarray-npy = "0.9.1"
[dev-dependencies]
argh = "0.1.13"
diff --git a/README.md b/README.md
index fb953e7..85a2517 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLO11](https://github.com/ultralytics/ultralytics)
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
- **Vision Models**: [RT-DETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro), [FastViT](https://github.com/apple/ml-fastvit), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [MobileOne](https://github.com/apple/ml-mobileone)
-- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)
+- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242), [Moondream2](https://github.com/vikhyat/moondream/tree/main)
- **OCR Models**: [FAST](https://github.com/czczup/FAST), [DB(PaddleOCR-Det)](https://arxiv.org/abs/1911.08947), [SVTR(PaddleOCR-Rec)](https://arxiv.org/abs/2205.00159), [SLANet](https://paddlepaddle.github.io/PaddleOCR/latest/algorithm/table_recognition/algorithm_table_slanet.html), [TrOCR](https://huggingface.co/microsoft/trocr-base-printed), [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
@@ -86,6 +86,7 @@
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Foundation for Human Vision Models | [demo](examples/sapiens) | ✅ | ✅ | ✅ | | |
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | ✅ | | |
+| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs
index 7248faf..52f7673 100644
--- a/examples/florence2/main.rs
+++ b/examples/florence2/main.rs
@@ -90,9 +90,11 @@ fn main() -> Result<()> {
Task::ObjectDetection,
Task::DenseRegionCaption,
// w/o inputs
- Task::OpenSetDetection("a vehicle"),
- Task::CaptionToPhraseGrounding("A vehicle with two wheels parked in front of a building."),
- Task::ReferringExpressionSegmentation("a vehicle"),
+ Task::OpenSetDetection("a vehicle".into()),
+ Task::CaptionToPhraseGrounding(
+ "A vehicle with two wheels parked in front of a building.".into(),
+ ),
+ Task::ReferringExpressionSegmentation("a vehicle".into()),
Task::RegionToSegmentation(
// 31, 156, 581, 373, // car
449, 270, 556, 372, // wheel
diff --git a/examples/moondream2/README.md b/examples/moondream2/README.md
new file mode 100644
index 0000000..e949db9
--- /dev/null
+++ b/examples/moondream2/README.md
@@ -0,0 +1,10 @@
+## Quick Start
+
+```shell
+cargo run -r -F cuda --example moondream2 -- --device 'cuda:0' --dtype i8 --scale 2b --task vqa:"What's in this image?"
+cargo run -r -F cuda --example moondream2 -- --device 'cuda:0' --dtype i8 --scale 2b --task cap:0
+cargo run -r -F cuda --example moondream2 -- --device 'cuda:0' --dtype i8 --scale 2b --task cap:1
+cargo run -r -F cuda --example moondream2 -- --device 'cuda:0' --dtype i8 --scale 2b --task open-od:person
+cargo run -r -F cuda --example moondream2 -- --device 'cuda:0' --dtype i8 --scale 2b --task open-kpt:person
+```
+
diff --git a/examples/moondream2/main.rs b/examples/moondream2/main.rs
new file mode 100644
index 0000000..299f590
--- /dev/null
+++ b/examples/moondream2/main.rs
@@ -0,0 +1,157 @@
+use anyhow::Result;
+use usls::{models::Moondream2, Annotator, DataLoader, Options, Scale, Task};
+
+#[derive(argh::FromArgs)]
+/// Example
+struct Args {
+ /// device
+ #[argh(option, default = "String::from(\"cpu:0\")")]
+ device: String,
+
+ /// source image
+ #[argh(
+ option,
+ default = "vec![
+ String::from(\"./assets/bus.jpg\"),
+ String::from(\"images/green-car.jpg\"),
+ ]"
+ )]
+ source: Vec<String>,
+
+ /// dtype
+ #[argh(option, default = "String::from(\"int4\")")]
+ dtype: String,
+
+ /// scale
+ #[argh(option, default = "String::from(\"0.5b\")")]
+ scale: String,
+
+ /// task
+ #[argh(option, default = "String::from(\"Caption: 0\")")]
+ task: String,
+}
+
+fn main() -> Result<()> {
+ tracing_subscriber::fmt()
+ .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+ .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
+ .init();
+ let args: Args = argh::from_env();
+
+ // build model
+ let (
+ options_vision_encoder,
+ options_vision_projection,
+ options_text_decoder,
+ options_text_encoder,
+ options_coord_decoder,
+ options_coord_encoder,
+ options_size_decoder,
+ options_size_encoder,
+ ) = match args.scale.as_str().try_into()? {
+ Scale::Billion(2.) => (
+ Options::moondream2_2b_vision_encoder(),
+ Options::moondream2_2b_vision_projection(),
+ Options::moondream2_2b_text_decoder(),
+ Options::moondream2_2b_text_encoder(),
+ Options::moondream2_2b_coord_decoder(),
+ Options::moondream2_2b_coord_encoder(),
+ Options::moondream2_2b_size_decoder(),
+ Options::moondream2_2b_size_encoder(),
+ ),
+ Scale::Billion(0.5) => (
+ Options::moondream2_0_5b_vision_encoder(),
+ Options::moondream2_0_5b_vision_projection(),
+ Options::moondream2_0_5b_text_decoder(),
+ Options::moondream2_0_5b_text_encoder(),
+ Options::moondream2_0_5b_coord_decoder(),
+ Options::moondream2_0_5b_coord_encoder(),
+ Options::moondream2_0_5b_size_decoder(),
+ Options::moondream2_0_5b_size_encoder(),
+ ),
+ _ => unimplemented!(),
+ };
+
+ let mut model = Moondream2::new(
+ options_vision_encoder
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ options_vision_projection
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ options_text_encoder
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ options_text_decoder
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ Some(
+ options_coord_encoder
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ ),
+ Some(
+ options_coord_decoder
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ ),
+ Some(
+ options_size_encoder
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ ),
+ Some(
+ options_size_decoder
+ .with_model_dtype(args.dtype.as_str().try_into()?)
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ ),
+ )?;
+
+ // load images
+ let xs = DataLoader::try_read_batch(&args.source)?;
+
+ // run with task
+ let task: Task = args.task.as_str().try_into()?;
+ let ys = model.forward(&xs, &task)?;
+
+ // annotate
+ match task {
+ Task::Caption(_) => {
+ println!("{}:", task);
+ for (i, y) in ys.iter().enumerate() {
+ if let Some(texts) = y.texts() {
+ println!("Image {}: {:?}\n", i, texts[0]);
+ }
+ }
+ }
+ Task::Vqa(query) => {
+ println!("Question: {}", query);
+ for (i, y) in ys.iter().enumerate() {
+ if let Some(texts) = y.texts() {
+ println!("Image {}: {:?}\n", i, texts[0]);
+ }
+ }
+ }
+ Task::OpenSetDetection(_) | Task::OpenSetKeypointsDetection(_) => {
+ println!("{:?}", ys);
+ let annotator = Annotator::default()
+ .with_bboxes_thickness(4)
+ .without_bboxes_conf(true)
+ .with_keypoints_radius(6)
+ .with_keypoints_name(true)
+ .with_saveout("moondream2");
+ annotator.annotate(&xs, &ys);
+ }
+ _ => unimplemented!("Unsupported moondream2 task."),
+ }
+
+ Ok(())
+}
diff --git a/src/misc/device.rs b/src/misc/device.rs
index e1029e1..ab04884 100644
--- a/src/misc/device.rs
+++ b/src/misc/device.rs
@@ -33,8 +33,8 @@ impl TryFrom<&str> for Device {
// device and its id
let d_id: Vec<&str> = s.trim().split(':').collect();
let (d, id) = match d_id.len() {
- 1 => (d_id[0], 0),
- 2 => (d_id[0], d_id[1].parse::<usize>().unwrap_or(0)),
+ 1 => (d_id[0].trim(), 0),
+ 2 => (d_id[0].trim(), d_id[1].trim().parse::<usize>().unwrap_or(0)),
_ => anyhow::bail!(
"Fail to parse device string: {s}. Expect: `device:device_id` or `device`. e.g. `cuda:0` or `cuda`"
),
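The added `trim()` calls make whitespace around the `:` separator harmless. A minimal sketch of the intended behavior, assuming `Device` stays re-exported at the crate root and derives `Debug`:

```rust
use usls::Device;

fn main() -> anyhow::Result<()> {
    // Both spellings should parse to the same device after this patch.
    let a = Device::try_from("cuda:0")?;
    let b = Device::try_from("cuda : 0")?; // inner whitespace is now trimmed
    assert_eq!(format!("{a:?}"), format!("{b:?}"));
    Ok(())
}
```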
diff --git a/src/misc/dtype.rs b/src/misc/dtype.rs
index 81f0d50..8e4dce2 100644
--- a/src/misc/dtype.rs
+++ b/src/misc/dtype.rs
@@ -3,6 +3,7 @@ use ort::tensor::TensorElementType;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DType {
Auto,
+ Int4,
Int8,
Int16,
Int32,
@@ -32,6 +33,7 @@ impl TryFrom<&str> for DType {
"u16" | "uint16" => Ok(Self::Uint16),
"u32" | "uint32" => Ok(Self::Uint32),
"u64" | "uint64" => Ok(Self::Uint64),
+ "i4" | "int4" => Ok(Self::Int4),
"i8" | "int8" => Ok(Self::Int8),
"i16" | "int=16" => Ok(Self::Int16),
"i32" | "int32" => Ok(Self::Int32),
@@ -52,6 +54,7 @@ impl std::fmt::Display for DType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let x = match self {
Self::Auto => "auto",
+ Self::Int4 => "int4",
Self::Int8 => "int8",
Self::Int16 => "int16",
Self::Int32 => "int32",
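A quick round-trip check for the new variant, assuming `DType` is re-exported at the crate root like the other enums:

```rust
use usls::DType;

fn main() -> anyhow::Result<()> {
    // "i4" and "int4" both map to the new Int4 variant,
    // and Display renders it back as "int4".
    let dt = DType::try_from("i4")?;
    assert_eq!(dt, DType::try_from("int4")?);
    assert_eq!(dt.to_string(), "int4");
    Ok(())
}
```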
diff --git a/src/misc/engine.rs b/src/misc/engine.rs
index a99ed27..04999ea 100644
--- a/src/misc/engine.rs
+++ b/src/misc/engine.rs
@@ -206,6 +206,7 @@ impl Engine {
x, dtype,
)?));
}
+
xs_
});
@@ -223,6 +224,7 @@ impl Engine {
ys.push_kv(name.as_str(), X::from(y))?;
}
});
+
Ok(ys)
} else {
anyhow::bail!("Failed to run with ONNXRuntime. No model info found.");
diff --git a/src/misc/options.rs b/src/misc/options.rs
index bc98179..5c77c4f 100644
--- a/src/misc/options.rs
+++ b/src/misc/options.rs
@@ -68,6 +68,11 @@ pub struct Options {
pub text_confs_2: Vec<f32>,
pub text_confs_3: Vec<f32>,
+ // Files
+ pub file: Option<String>,
+ pub file_2: Option<String>,
+ pub file_3: Option<String>,
+
// For classification
pub apply_softmax: Option<bool>,
@@ -149,6 +154,9 @@ impl Default for Options {
text_names: None,
text_names_2: None,
text_names_3: None,
+ file: None,
+ file_2: None,
+ file_3: None,
class_confs: vec![0.3f32],
class_confs_2: vec![0.3f32],
class_confs_3: vec![0.3f32],
@@ -320,11 +328,6 @@ impl Options {
.try_fetch(&format!("{}/{}", self.model_name, self.model_file))?;
}
}
-
- // let stem = crate::try_fetch_stem(&self.model_file)?;
- // self.model_spec = format!("{}/{}", self.model_name, stem);
- // self.model_file =
- // Hub::default().try_fetch(&format!("{}/{}", self.model_name, self.model_file))?;
}
Ok(self)
@@ -408,7 +411,7 @@ impl Options {
.unwrap_or(&format!("{}/tokenizer.json", self.model_name)),
)?,
)
- .map_err(|_| anyhow::anyhow!("No `tokenizer.json` found"))?;
+ .map_err(|err| anyhow::anyhow!("Faild to build tokenizer: {err}"))?;
// TODO: padding
// if `max_length` specified: use `Fixed` strategy
diff --git a/src/misc/scale.rs b/src/misc/scale.rs
index 4dc5ab4..ecab3f9 100644
--- a/src/misc/scale.rs
+++ b/src/misc/scale.rs
@@ -13,6 +13,8 @@ pub enum Scale {
P,
A,
F,
+ Million(f32),
+ Billion(f32),
}
impl std::fmt::Display for Scale {
@@ -31,6 +33,8 @@ impl std::fmt::Display for Scale {
Self::P => "p",
Self::A => "a",
Self::F => "f",
+ Self::Million(x) => &format!("{x}m"),
+ Self::Billion(x) => &format!("{x}b"), // x.0 -> x
};
write!(f, "{}", x)
}
@@ -77,6 +81,20 @@ impl TryFrom<&str> for Scale {
"p" | "pico" => Ok(Self::P),
"a" | "atto" => Ok(Self::A),
"f" | "femto" => Ok(Self::F),
+ scale if scale.ends_with("b") => {
+ let num_str = &scale[..scale.len() - 1];
+ match num_str.parse::<f32>() {
+ Ok(x) => Ok(Self::Billion(x)),
+ Err(_) => anyhow::bail!("Invalid Billion format: {}", scale),
+ }
+ }
+ scale if scale.ends_with("m") => {
+ let num_str = &scale[..scale.len() - 1];
+ match num_str.parse::<f32>() {
+ Ok(x) => Ok(Self::Million(x)),
+ Err(_) => anyhow::bail!("Invalid Million format: {}", scale),
+ }
+ }
x => anyhow::bail!("Unsupported model scale: {:?}", x),
}
}
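The new suffix arms accept any float prefix, so the `0.5b` and `2b` strings used by the moondream2 presets both parse. A minimal sketch (`Scale` is already exported; see the example's imports):

```rust
use usls::Scale;

fn main() -> anyhow::Result<()> {
    assert!(matches!(Scale::try_from("0.5b")?, Scale::Billion(x) if x == 0.5));
    assert!(matches!(Scale::try_from("300m")?, Scale::Million(x) if x == 300.0));
    // Display round-trips: `{}` on 2.0_f32 prints "2", hence "2b".
    assert_eq!(Scale::Billion(2.).to_string(), "2b");
    Ok(())
}
```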
diff --git a/src/misc/task.rs b/src/misc/task.rs
index 80e5c33..c10dece 100644
--- a/src/misc/task.rs
+++ b/src/misc/task.rs
@@ -1,4 +1,4 @@
-#[derive(Debug, Copy, Clone, Ord, Eq, PartialOrd, PartialEq)]
+#[derive(Debug, Clone, Ord, Eq, PartialOrd, PartialEq)]
pub enum Task {
/// Image classification task.
/// Input: image
@@ -32,7 +32,7 @@ pub enum Task {
/// Input: image
/// Output: bounding boxes, class labels (including an "unknown" category for unfamiliar objects), and detection scores
/// Open set detection task, with String query
- OpenSetDetection(&'static str),
+ OpenSetDetection(String),
/// Task for generating brief descriptions of dense regions in the image.
/// Input: image
/// Output: bounding boxes (bboxes), brief phrase labels, and optional scores for detected regions
@@ -44,6 +44,7 @@ pub enum Task {
/// Output: coordinates of detected keypoints
KeypointsDetection,
Pose,
+ OpenSetKeypointsDetection(String),
/// Semantic segmentation task, segmenting the image into different semantic regions.
/// Input: image
@@ -97,12 +98,12 @@ pub enum Task {
/// Input: image and text
/// Output: image region and the corresponding phrase
/// caption to phrase grounding
- CaptionToPhraseGrounding(&'static str),
+ CaptionToPhraseGrounding(String),
/// Referring expression segmentation task, segmenting objects in the image based on a text description.
/// Input: image and referring expression
/// Output: a segmentation mask for the object referred to by the text
- ReferringExpressionSegmentation(&'static str),
+ ReferringExpressionSegmentation(String),
/// Region-to-segmentation task, similar to combining object detection with segmentation (e.g., YOLO + SAM).
/// Input: image and region proposals
@@ -125,7 +126,7 @@ pub enum Task {
/// Visual question answering (VQA) task, answering questions related to an image.
/// Input: image and question text
/// Output: the answer to the question
- Vqa(&'static str),
+ Vqa(String),
/// Optical character recognition (OCR) task, recognizing text in an image.
/// Input: image
@@ -156,6 +157,7 @@ impl std::fmt::Display for Task {
Self::Ocr => "ocr",
Self::OcrWithRegion => "ocr-with-region",
Self::Vqa(_) => "vqa",
+ Self::OpenSetKeypointsDetection(_) => "open-set-keypoints-detection",
_ => todo!(),
};
write!(f, "{}", x)
@@ -166,13 +168,33 @@ impl TryFrom<&str> for Task {
type Error = anyhow::Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
+ // TODO
match s.to_lowercase().as_str() {
"cls" | "classify" | "classification" => Ok(Self::ImageClassification),
"det" | "od" | "detect" => Ok(Self::ObjectDetection),
"kpt" | "pose" => Ok(Self::KeypointsDetection),
"seg" | "segment" => Ok(Self::InstanceSegmentation),
"obb" => Ok(Self::OrientedObjectDetection),
- _ => todo!(), // x => anyhow::bail!("Unsupported model task: {}", x),
+ "cap" | "cap0" | "caption" => Ok(Self::Caption(0)),
+ "cap1" | "caption1" => Ok(Self::Caption(1)),
+ "cap2" | "caption2" => Ok(Self::Caption(2)),
+ x if x.contains(":") => {
+ let t_tt: Vec<&str> = x.trim().split(':').collect();
+ let (t, tt) = match t_tt.len() {
+ 2 => (t_tt[0].trim(), t_tt[1].trim()),
+ _ => anyhow::bail!(
+ "Fail to parse task: {x}. Expect: `task:content`. e.g. `vqa:What's in this image?`"
+ ),
+ };
+ match t {
+ "cap" | "caption" => Ok(Self::Caption(tt.parse::().unwrap_or(0) as u8)),
+ "vqa" => Ok(Self::Vqa(tt.into())),
+ "open-det" | "open-od" => Ok(Self::OpenSetDetection(tt.into())),
+ "open-kpt" | "open-pose" => Ok(Self::OpenSetKeypointsDetection(tt.into())),
+ _ => todo!(),
+ }
+ }
+ _ => todo!(),
}
}
}
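One subtlety worth knowing: the parser lowercases the whole input before splitting on `:`, so free-text payloads (VQA questions, object names) reach the model lowercased. A sketch of the new `task:content` forms (`Task` is already exported; see the example's imports):

```rust
use usls::Task;

fn main() -> anyhow::Result<()> {
    let t = Task::try_from("open-od:person")?;
    assert!(matches!(t, Task::OpenSetDetection(ref s) if s == "person"));
    // The whole string is lowercased first, so the query text is too.
    let t = Task::try_from("vqa:What color is the bus?")?;
    assert!(matches!(t, Task::Vqa(ref q) if q == "what color is the bus?"));
    Ok(())
}
```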
diff --git a/src/models/florence2/impl.rs b/src/models/florence2/impl.rs
index b4094e2..b138d7e 100644
--- a/src/models/florence2/impl.rs
+++ b/src/models/florence2/impl.rs
@@ -88,7 +88,7 @@ impl Florence2 {
.quantize(&[*x0, *y0, *x1, *y1], (image_width, image_height));
Task::RegionToDescription(xyxy[0], xyxy[1], xyxy[2], xyxy[3])
}
- _ => *task,
+ _ => task.clone(),
}
}
diff --git a/src/models/mod.rs b/src/models/mod.rs
index 79db3c5..9b5dc02 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -16,6 +16,7 @@ mod grounding_dino;
mod linknet;
mod mobileone;
mod modnet;
+mod moondream2;
mod picodet;
mod pipeline;
mod rtdetr;
@@ -37,6 +38,7 @@ pub use dinov2::*;
pub use florence2::*;
pub use grounding_dino::*;
pub use modnet::*;
+pub use moondream2::*;
pub use picodet::*;
pub use pipeline::*;
pub use rtdetr::*;
diff --git a/src/models/moondream2/README.md b/src/models/moondream2/README.md
new file mode 100644
index 0000000..59e37b3
--- /dev/null
+++ b/src/models/moondream2/README.md
@@ -0,0 +1,9 @@
+# moondream: A tiny vision language model that kicks ass and runs anywhere
+
+## Official Repository
+
+The official repository can be found on: [GitHub](https://github.com/vikhyat/moondream/tree/main)
+
+## Example
+
+Refer to the [example](../../../examples/moondream2)
diff --git a/src/models/moondream2/config.rs b/src/models/moondream2/config.rs
new file mode 100644
index 0000000..96d0bf9
--- /dev/null
+++ b/src/models/moondream2/config.rs
@@ -0,0 +1,117 @@
+/// Model configuration for `moondream2`
+impl crate::Options {
+ pub fn moondream2() -> Self {
+ Self::default()
+ .with_model_name("moondream2")
+ .with_model_num_dry_run(0)
+ }
+
+ pub fn moondream2_0_5b() -> Self {
+ Self::moondream2().with_model_scale(crate::Scale::Billion(0.5))
+ }
+
+ pub fn moondream2_0_5b_vision_encoder() -> Self {
+ Self::moondream2_0_5b()
+ .with_model_ixx(0, 0, (1, 3, 4).into()) // patch count
+ .with_model_kind(crate::Kind::Vision)
+ .with_image_mean(&[0.5, 0.5, 0.5])
+ .with_image_std(&[0.5, 0.5, 0.5])
+ .with_normalize(true)
+ .with_resize_mode(crate::ResizeMode::FitExact)
+ .with_resize_filter("catmullrom")
+ .with_model_file("0.5b-vision-encoder.onnx")
+ }
+
+ pub fn moondream2_0_5b_vision_projection() -> Self {
+ Self::moondream2_0_5b()
+ .with_batch_size(1)
+ .with_model_kind(crate::Kind::Vision)
+ .with_model_file("0.5b-vision-projection.onnx")
+ }
+
+ pub fn moondream2_0_5b_text_decoder() -> Self {
+ Self::moondream2_0_5b()
+ .with_batch_size(1)
+ .with_model_kind(crate::Kind::Language)
+ .with_model_file("0.5b-text-decoder.onnx")
+ }
+
+ pub fn moondream2_0_5b_text_encoder() -> Self {
+ Self::moondream2_0_5b()
+ .with_batch_size(1)
+ .with_model_kind(crate::Kind::Language)
+ .with_model_file("0.5b-text-encoder.onnx")
+ }
+
+ pub fn moondream2_0_5b_coord_encoder() -> Self {
+ Self::moondream2_0_5b()
+ .with_batch_size(1)
+ .with_model_file("0.5b-coord-encoder.onnx")
+ }
+
+ pub fn moondream2_0_5b_coord_decoder() -> Self {
+ Self::moondream2_0_5b()
+ .with_batch_size(1)
+ .with_model_file("0.5b-coord-decoder.onnx")
+ }
+
+ pub fn moondream2_0_5b_size_encoder() -> Self {
+ Self::moondream2_0_5b()
+ .with_batch_size(1)
+ .with_model_file("0.5b-size-encoder.onnx")
+ }
+
+ pub fn moondream2_0_5b_size_decoder() -> Self {
+ Self::moondream2_0_5b()
+ .with_batch_size(1)
+ .with_model_file("0.5b-size-decoder.onnx")
+ }
+
+ pub fn moondream2_2b_vision_encoder() -> Self {
+ Self::moondream2_0_5b_vision_encoder()
+ .with_model_scale(crate::Scale::Billion(2.))
+ .with_model_file("2b-vision-encoder.onnx")
+ }
+
+ pub fn moondream2_2b_vision_projection() -> Self {
+ Self::moondream2_0_5b_vision_projection()
+ .with_model_scale(crate::Scale::Billion(2.))
+ .with_model_file("2b-vision-projection.onnx")
+ }
+
+ pub fn moondream2_2b_text_decoder() -> Self {
+ Self::moondream2_0_5b_text_decoder()
+ .with_model_scale(crate::Scale::Billion(2.))
+ .with_model_file("2b-text-decoder.onnx")
+ }
+
+ pub fn moondream2_2b_text_encoder() -> Self {
+ Self::moondream2_0_5b_text_encoder()
+ .with_model_scale(crate::Scale::Billion(2.))
+ .with_model_file("2b-text-encoder.onnx")
+ }
+
+ pub fn moondream2_2b_coord_encoder() -> Self {
+ Self::moondream2_0_5b_coord_encoder()
+ .with_model_scale(crate::Scale::Billion(2.))
+ .with_model_file("2b-coord-encoder.onnx")
+ }
+
+ pub fn moondream2_2b_coord_decoder() -> Self {
+ Self::moondream2_0_5b_coord_decoder()
+ .with_model_scale(crate::Scale::Billion(2.))
+ .with_model_file("2b-coord-decoder.onnx")
+ }
+
+ pub fn moondream2_2b_size_encoder() -> Self {
+ Self::moondream2_0_5b_size_encoder()
+ .with_model_scale(crate::Scale::Billion(2.))
+ .with_model_file("2b-size-encoder.onnx")
+ }
+
+ pub fn moondream2_2b_size_decoder() -> Self {
+ Self::moondream2_0_5b_size_decoder()
+ .with_model_scale(crate::Scale::Billion(2.))
+ .with_model_file("2b-size-decoder.onnx")
+ }
+}
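The 2b presets reuse the 0.5b builders and override only the scale and model file, so per-run dtype and device are layered on afterwards, exactly as the example binary does. A sketch (note that `commit()` resolves the model path and may download the file):

```rust
use usls::Options;

fn main() -> anyhow::Result<()> {
    // Preset first, then per-run overrides, then commit.
    let _vision = Options::moondream2_2b_vision_encoder()
        .with_model_dtype("int4".try_into()?)
        .with_model_device("cpu:0".try_into()?)
        .commit()?;
    Ok(())
}
```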
diff --git a/src/models/moondream2/impl.rs b/src/models/moondream2/impl.rs
new file mode 100644
index 0000000..cecd110
--- /dev/null
+++ b/src/models/moondream2/impl.rs
@@ -0,0 +1,645 @@
+use aksr::Builder;
+use anyhow::{Context, Result};
+use image::{DynamicImage, GenericImageView};
+use ndarray::{s, Array, Array2, Array3, Axis, IxDyn};
+use ndarray_npy::ReadNpyExt;
+
+use crate::{
+ BaseModelTextual, Bbox, DType, Engine, Hub, Keypoint, LogitsSampler, Options, Processor, Scale,
+ Task, Ts, Xs, Ys, X, Y,
+};
+
+#[derive(Builder, Debug)]
+pub struct Moondream2 {
+ vision_encoder: VisionEncoder,
+ vision_projection: VisionProjection,
+ pub text_decoder: BaseModelTextual,
+ text_encoder: BaseModelTextual,
+ coord_decoder: Option<BaseModelTextual>,
+ coord_encoder: Option<BaseModelTextual>,
+ size_decoder: Option<BaseModelTextual>,
+ size_encoder: Option<BaseModelTextual>,
+ initial_kv_cache: X, // TODO: use f16
+ scale: Scale,
+ dtype: DType,
+ max_length: usize,
+ eos_token_id: u32,
+ max_objects: usize,
+}
+
+impl Moondream2 {
+ // TODO
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ options_vision_encoder: Options,
+ options_vision_projection: Options,
+ options_text_encoder: Options,
+ options_text_decoder: Options,
+ options_coord_encoder: Option<Options>,
+ options_coord_decoder: Option<Options>,
+ options_size_encoder: Option<Options>,
+ options_size_decoder: Option<Options>,
+ ) -> Result<Self> {
+ let max_length = 2048;
+ let max_objects = 50;
+ let eos_token_id = 50256;
+ let dtype = options_vision_encoder.model_dtype;
+ let scale = options_vision_encoder
+ .model_scale
+ .unwrap_or(Scale::Billion(0.5));
+ let initial_kv_cache: X = KVCache::new(&scale, &dtype)?.0.into();
+ let vision_encoder = VisionEncoder::new(options_vision_encoder)?;
+ let vision_projection = VisionProjection::new(options_vision_projection)?;
+ let text_decoder = BaseModelTextual::new(options_text_decoder)?;
+ let text_encoder = BaseModelTextual::new(options_text_encoder)?;
+ let coord_decoder = options_coord_decoder
+ .map(BaseModelTextual::new)
+ .transpose()?;
+ let coord_encoder = options_coord_encoder
+ .map(BaseModelTextual::new)
+ .transpose()?;
+ let size_decoder = options_size_decoder
+ .map(BaseModelTextual::new)
+ .transpose()?;
+ let size_encoder = options_size_encoder
+ .map(BaseModelTextual::new)
+ .transpose()?;
+
+ Ok(Self {
+ vision_encoder,
+ vision_projection,
+ text_decoder,
+ initial_kv_cache,
+ max_length,
+ max_objects,
+ text_encoder,
+ coord_decoder,
+ coord_encoder,
+ size_encoder,
+ size_decoder,
+ eos_token_id,
+ scale,
+ dtype,
+ })
+ }
+
+ pub fn encode_image(&mut self, x: &DynamicImage) -> Result<X> {
+ let patches_emb = self.vision_encoder.encode(x)?.clone().insert_axis(0)?;
+ let image_embedding = self.vision_projection.inference(patches_emb.into())?[0].to_owned();
+
+ Ok(image_embedding)
+ }
+
+ pub fn forward(&mut self, xs: &[DynamicImage], task: &Task) -> Result<Ys> {
+ let mut ys: Vec<Y> = Vec::new();
+ for x in xs.iter() {
+ let y = self.forward_once(x, task)?;
+ ys.push(y);
+ }
+
+ Ok(ys.into())
+ }
+
+ pub fn forward_once(&mut self, images: &DynamicImage, task: &Task) -> Result<Y> {
+ let image_embedding = self.encode_image(images)?;
+ let kv_cache = self.prepare_kv_cache(&image_embedding)?;
+
+ match task {
+ Task::Caption(n) => {
+ let input_ids = match n {
+ 0 => vec![198., 198., 16438., 8305., 25.],
+ _ => vec![198., 198., 24334., 1159., 25.],
+ };
+ let text = self.generate_text(&input_ids, kv_cache)?;
+ let y = Y::default().with_texts(&[text.into()]);
+
+ Ok(y)
+ }
+ Task::Vqa(query) => {
+ let input_ids: Vec<_> = [198., 198., 24361., 25.]
+ .iter()
+ .chain(
+ &self
+ .text_encoder
+ .processor()
+ .encode_text_ids(query, false)?,
+ )
+ .chain(&[198., 198., 33706., 25.])
+ .cloned()
+ .collect();
+
+ let text = self.generate_text(&input_ids, kv_cache)?;
+ let y = Y::default().with_texts(&[text.into()]);
+
+ Ok(y)
+ }
+ Task::OpenSetDetection(object) => {
+ let input_ids: Vec<_> = [198., 198., 47504., 25.]
+ .iter()
+ .chain(
+ &self
+ .text_encoder
+ .processor()
+ .encode_text_ids(&format!(" {}", object), false)?,
+ )
+ .chain(&[628.])
+ .cloned()
+ .collect();
+ let (_, y_bboxes) =
+ self.generate_points_boxes(&input_ids, kv_cache, object, true)?;
+
+ Ok(Y::default().with_bboxes(&y_bboxes))
+ }
+ Task::OpenSetKeypointsDetection(object) => {
+ let input_ids: Vec<_> = [198., 198., 12727., 25.]
+ .iter()
+ .chain(
+ &self
+ .text_encoder
+ .processor()
+ .encode_text_ids(&format!(" {}", object), false)?,
+ )
+ .chain(&[628.])
+ .cloned()
+ .collect();
+ let (y_kpts, _) =
+ self.generate_points_boxes(&input_ids, kv_cache, object, false)?;
+
+ Ok(Y::default().with_keypoints(&y_kpts))
+ }
+ x => anyhow::bail!("Unsupported Moondream2 task: {}", x),
+ }
+ }
+
+ fn generate_text(&mut self, input_ids: &[f32], kv_cache: Array<f32, IxDyn>) -> Result<String> {
+ let input_ids = X::from(input_ids.to_vec()).insert_axis(0)?;
+ let mut input_embeds = self.text_encoder.inference(Xs::from(input_ids))?[0].to_owned();
+ let logits_sampler = LogitsSampler::new();
+ let mut token_ids: Vec<u32> = Vec::new();
+ let mut pos = self.vision_projection.seq_len() + self.initial_kv_cache.shape()[4];
+ let mut inc = input_embeds.shape()[1];
+ let mut kv_cache = kv_cache.clone();
+
+ // generate
+ for _ in 0..self.max_length {
+ // TODO
+ let input = Xs::from(vec![
+ input_embeds.clone(),
+ kv_cache
+ .slice(s![.., .., .., .., ..pos, ..])
+ .into_owned()
+ .into_dyn()
+ .into(),
+ ]);
+ let decoder_outputs = self.text_decoder.inference(input)?;
+
+ // update
+ let logits = &decoder_outputs["logits"];
+ let new_kv_cache = &decoder_outputs["new_kv_cache"];
+ kv_cache
+ .slice_mut(s![.., .., .., .., pos..pos + inc, ..])
+ .assign(new_kv_cache);
+ pos += inc;
+
+ // decode
+ let token_id = logits_sampler.decode(
+ logits
+ .slice(s![-1, ..])
+ .as_slice()
+ .context("Failed to get slice when decode `logits`")?,
+ )?;
+
+ // break
+ if token_id == self.eos_token_id {
+ break;
+ }
+
+ // update
+ token_ids.push(token_id);
+ inc = 1;
+
+ // encode
+ let next_tokens = X::from(vec![token_id as f32]).insert_axis(1)?;
+ input_embeds = self.text_encoder.inference(Xs::from(next_tokens))?[0].to_owned();
+ }
+
+ let text = self
+ .text_encoder
+ .processor()
+ .decode_tokens(&token_ids, true)?;
+
+ Ok(text)
+ }
+
+ fn generate_points_boxes(
+ &mut self,
+ input_ids: &[f32],
+ kv_cache: Array<f32, IxDyn>,
+ object: &str,
+ generate_boxes: bool,
+ ) -> Result<(Vec<Vec<Keypoint>>, Vec<Bbox>)> {
+ let mut y_bboxes: Vec<Bbox> = Vec::new();
+ let mut y_kpts: Vec<Vec<Keypoint>> = Vec::new();
+ let (image_height, image_width) = self.vision_encoder.processor.image0s_size[0];
+ let mut pos = self.vision_projection.seq_len() + self.initial_kv_cache.shape()[4];
+ let logits_sampler = LogitsSampler::new();
+
+ // initial input_embeds
+ let input_ids = X::from(input_ids.to_vec()).insert_axis(0)?;
+ let mut hidden = self.text_encoder.inference(Xs::from(input_ids))?[0].to_owned();
+ let mut kv_cache = kv_cache;
+
+ // generate
+ loop {
+ let logits = self.run_decoder(&mut hidden, &mut kv_cache, &mut pos)?;
+
+ // decode
+ let token_id = logits_sampler.decode(
+ logits
+ .slice(s![-1, ..])
+ .as_slice()
+ .context("Failed to get slice for `logits`")?,
+ )?;
+
+ // break
+ if token_id == self.eos_token_id {
+ break;
+ }
+
+ // cx
+ let input: X = hidden.slice(s![0, -1, ..]).into_owned().into_dyn().into();
+ let cx = self
+ .coord_decoder
+ .as_mut()
+ .unwrap()
+ .inference(Xs::from(input))?[0]
+ .clone(); // [1024]
+ let ratio = cx.shape()[0] as f32;
+ let cx = logits_sampler
+ .decode(cx.as_slice().context("Failed to get slice for `cx`")?)?
+ as f32
+ / ratio;
+ hidden = self
+ .coord_encoder
+ .as_mut()
+ .unwrap()
+ .inference(Xs::from(X::from(vec![cx])))?[0]
+ .clone()
+ .insert_axis(0)?
+ .insert_axis(0)?;
+
+ // cy
+ let _logits = self.run_decoder(&mut hidden, &mut kv_cache, &mut pos)?;
+ let input: X = hidden.slice(s![0, -1, ..]).into_owned().into_dyn().into();
+ let cy = self
+ .coord_decoder
+ .as_mut()
+ .unwrap()
+ .inference(Xs::from(input))?[0]
+ .clone();
+ let ratio = cy.shape()[0] as f32;
+
+ let cy = logits_sampler
+ .decode(cy.as_slice().context("Failed to get slice for `cy`")?)?
+ as f32
+ / ratio;
+
+ hidden = self
+ .coord_encoder
+ .as_mut()
+ .unwrap()
+ .inference(Xs::from(X::from(vec![cy])))?[0]
+ .clone()
+ .insert_axis(0)?
+ .insert_axis(0)?;
+
+ if !generate_boxes {
+ y_kpts.push(vec![Keypoint::from((
+ cx * image_width as f32,
+ cy * image_height as f32,
+ 0,
+ ))
+ .with_name(object)]);
+
+ // keep?
+ if y_kpts.len() > self.max_objects {
+ break;
+ }
+ } else {
+ // wh
+ let _logits = self.run_decoder(&mut hidden, &mut kv_cache, &mut pos)?;
+ let input: X = hidden.slice(s![0, -1, ..]).into_owned().into_dyn().into();
+ let size = self
+ .size_decoder
+ .as_mut()
+ .unwrap()
+ .inference(Xs::from(input))?[0]
+ .clone(); // [2, 1024]
+
+ let ratio = size.shape()[1] as f32;
+ let w = logits_sampler.decode(
+ size.slice(s![0, ..])
+ .as_slice()
+ .context("Failed to get slice when decode `w`")?,
+ )? as f32
+ / ratio;
+
+ // h
+ let h = logits_sampler.decode(
+ size.slice(s![1, ..])
+ .as_slice()
+ .context("Failed to get slice when decode `h`")?,
+ )? as f32
+ / ratio;
+
+ hidden = self
+ .size_encoder
+ .as_mut()
+ .unwrap()
+ .inference(Xs::from(X::from(vec![w, h])))?[0]
+ .clone()
+ .insert_axis(0)?
+ .insert_axis(0)?; // [1024]
+
+ let xmin = cx - w / 2.;
+ let ymin = cy - h / 2.;
+
+ y_bboxes.push(
+ Bbox::from((
+ xmin * image_width as f32,
+ ymin * image_height as f32,
+ w * image_width as f32,
+ h * image_height as f32,
+ ))
+ .with_name(object)
+ .with_id(0)
+ .with_confidence(1.),
+ );
+
+ // Keep?
+ if y_bboxes.len() > self.max_objects {
+ break;
+ }
+ }
+ }
+
+ Ok((y_kpts, y_bboxes))
+ }
+
+ fn prepare_kv_cache(&mut self, image_embedding: &X) -> Result<Array<f32, IxDyn>> {
+ let kv_cache_new = self.text_decoder.inference(Xs::from(vec![
+ image_embedding.clone(),
+ self.initial_kv_cache.clone(),
+ ]))?["new_kv_cache"]
+ .to_owned();
+
+ // TODO
+ let kv_cache_new = ndarray::concatenate(
+ Axis(4),
+ &[kv_cache_new.view(), self.initial_kv_cache.view()],
+ )?;
+
+ // fill with max sequence length
+ let mut shapes = self.initial_kv_cache.shape().to_vec();
+ shapes[4] = self.max_length;
+ let mut kv_cache = Array::zeros(shapes);
+ kv_cache
+ .slice_mut(s![.., .., .., .., ..kv_cache_new.dim()[4], ..])
+ .assign(&kv_cache_new);
+
+ Ok(kv_cache.into_dyn())
+ }
+
+ fn run_decoder(
+ &mut self,
+ input_embeds: &mut X,
+ kv_cache: &mut Array<f32, IxDyn>,
+ pos: &mut usize,
+ ) -> Result<X> {
+ let decoder_outputs = self.text_decoder.inference(Xs::from(vec![
+ input_embeds.clone(),
+ kv_cache
+ .slice(s![.., .., .., .., ..*pos, ..])
+ .into_owned()
+ .into_dyn()
+ .into(),
+ ]))?;
+ let hidden = &decoder_outputs["hidden"];
+ let new_kv_cache = &decoder_outputs["new_kv_cache"];
+
+ // update
+ let inc = hidden.shape()[1]; // -2
+ kv_cache
+ .slice_mut(s![.., .., .., .., *pos..*pos + inc, ..])
+ .assign(new_kv_cache);
+ *pos += inc;
+ *input_embeds = hidden.to_owned();
+
+ Ok(decoder_outputs["logits"].to_owned())
+ }
+}
+
+#[derive(Debug, Builder)]
+pub struct VisionEncoder {
+ engine: Engine,
+ num_patch: usize,
+ patch_size: usize,
+ processor: Processor,
+ ts: Ts,
+}
+
+impl VisionEncoder {
+ pub fn new(options: Options) -> Result<Self> {
+ let engine = options.to_engine()?;
+ let (num_patch, patch_size, ts) = (
+ engine.batch().opt(),
+ engine.try_height().unwrap_or(&378.into()).opt(),
+ engine.ts.clone(),
+ );
+ let processor = options
+ .to_processor()?
+ .with_image_width(patch_size as _)
+ .with_image_height(patch_size as _);
+
+ Ok(Self {
+ engine,
+ patch_size,
+ num_patch,
+ processor,
+ ts,
+ })
+ }
+
+ fn create_patches(
+ image: &DynamicImage,
+ image_patch_size: usize,
+ ) -> (Vec<DynamicImage>, (u32, u32)) {
+ let mut patches = vec![image.clone()];
+ let image = image.to_rgb8();
+
+ let res_templates = vec![(1, 2), (2, 1), (2, 2)];
+ let (im_width, im_height) = image.dimensions();
+ let max_dim = im_width.max(im_height);
+ let selected_template = if max_dim < (image_patch_size as f32 * 1.4) as u32 {
+ (1, 1)
+ } else {
+ let aspect_ratio = im_width as f32 / im_height as f32;
+ res_templates
+ .into_iter()
+ .min_by(|a, b| {
+ let diff_a = ((a.1 as f32 / a.0 as f32) - aspect_ratio).abs();
+ let diff_b = ((b.1 as f32 / b.0 as f32) - aspect_ratio).abs();
+ diff_a.partial_cmp(&diff_b).unwrap()
+ })
+ .unwrap()
+ };
+ let patch_width = im_width / selected_template.1;
+ let patch_height = im_height / selected_template.0;
+
+ for row in 0..selected_template.0 {
+ for col in 0..selected_template.1 {
+ let x_min = col * patch_width;
+ let y_min = row * patch_height;
+ let _x_max = x_min + patch_width;
+ let _y_max = y_min + patch_height;
+ let cropped = image
+ .view(x_min, y_min, patch_width, patch_height)
+ .to_image();
+
+ patches.push(DynamicImage::from(cropped));
+ }
+ }
+
+ (patches, selected_template)
+ }
+
+ pub fn inference(&mut self, xs: Xs) -> Result<Xs> {
+ self.engine.run(xs)
+ }
+
+ pub fn encode(&mut self, x: &DynamicImage) -> Result<X> {
+ let (patches, selected_template) = Self::create_patches(x, self.patch_size);
+ let patches = self.processor.process_images(&patches)?;
+ let template = (
+ (selected_template.0 as usize),
+ (selected_template.1 as usize),
+ );
+ let patch_emb = self.inference(patches.clone().into())?[0].clone();
+ let patch_emb = patch_emb.clone().0.into_dimensionality::<ndarray::Ix3>()?;
+ let patch_emb = Self::process_patch_emb(patch_emb, template)?;
+ let patch_emb = X::from(patch_emb.into_dyn()); // TODO .insert_axis(x),
+
+ Ok(patch_emb)
+ }
+
+ fn process_patch_emb(patch_emb: Array3<f32>, template: (usize, usize)) -> Result<Array2<f32>> {
+ let (_, seq_len, enc_dim) = patch_emb.dim(); // (N, 729, 720)
+ let global_patch = patch_emb.slice(s![0, .., ..]).into_owned();
+ if template == (1, 1) {
+ Ok(ndarray::concatenate(
+ Axis(1),
+ &[global_patch.view(), global_patch.view()],
+ )?)
+ } else {
+ let w = (seq_len as f32).sqrt() as usize;
+ let mut rows = Vec::new();
+ for r in 0..template.0 {
+ let mut row = Vec::new();
+ for c in 0..template.1 {
+ let idx = r * template.1 + c;
+ let patch = patch_emb.slice(s![idx, .., ..]).into_owned();
+ let patch = patch.into_shape_with_order((w, w, enc_dim))?;
+ row.push(patch);
+ }
+ let row_concat = ndarray::concatenate(
+ Axis(1),
+ &row.iter().map(|x| x.view()).collect::<Vec<_>>(),
+ )?;
+ rows.push(row_concat);
+ }
+
+ let patch_emb =
+ ndarray::concatenate(Axis(0), &rows.iter().map(|x| x.view()).collect::<Vec<_>>())?;
+ let patch_emb = Self::adaptive_avg_pool2d(patch_emb, (w, w))
+ .into_shape_with_order((w * w, enc_dim))?;
+
+ Ok(ndarray::concatenate(
+ Axis(1),
+ &[global_patch.view(), patch_emb.view()],
+ )?)
+ }
+ }
+
+ fn adaptive_avg_pool2d(x: Array3<f32>, output_size: (usize, usize)) -> Array3<f32> {
+ let (height, width, channels) = x.dim();
+ let (out_height, out_width) = output_size;
+ let stride_h = height / out_height;
+ let stride_w = width / out_width;
+ let kernel_h = height - (out_height - 1) * stride_h;
+ let kernel_w = width - (out_width - 1) * stride_w;
+ let mut output = Array3::zeros((out_height, out_width, channels));
+ for i in 0..out_height {
+ for j in 0..out_width {
+ let h_start = i * stride_h;
+ let h_end = h_start + kernel_h;
+ let w_start = j * stride_w;
+ let w_end = w_start + kernel_w;
+
+ for c in 0..channels {
+ let mut sum = 0.0;
+ let mut count = 0;
+
+ for h in h_start..h_end {
+ for w in w_start..w_end {
+ if h < height && w < width {
+ sum += x[(h, w, c)];
+ count += 1;
+ }
+ }
+ }
+ output[(i, j, c)] = sum / count as f32;
+ }
+ }
+ }
+
+ output
+ }
+}
+
+#[derive(Debug, Builder)]
+pub struct VisionProjection {
+ engine: Engine,
+ seq_len: usize,
+ ts: Ts,
+}
+
+impl VisionProjection {
+ pub fn new(options: Options) -> Result<Self> {
+ let engine = options.to_engine()?;
+ let (seq_len, ts) = (engine.inputs_minoptmax[0][1].opt(), engine.ts.clone());
+
+ Ok(Self {
+ engine,
+ seq_len,
+ ts,
+ })
+ }
+
+ pub fn inference(&mut self, xs: Xs) -> Result<Xs> {
+ self.engine.run(xs)
+ }
+}
+
+#[derive(Builder, Debug)]
+struct KVCache(pub Array<f32, IxDyn>);
+
+impl KVCache {
+ pub fn new(scale: &Scale, dtype: &DType) -> Result<Self> {
+ let f = format!("moondream2/{}-initial-kv-cache-{}.npy", scale, dtype);
+ let f = Hub::default().try_fetch(&f)?;
+ let file = std::fs::File::open(f)?;
+ let x = Array::<f32, IxDyn>::read_npy(file)?.into_dyn();
+
+ Ok(Self(x))
+ }
+}
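All of the decoding loops above share one pattern: `prepare_kv_cache` preallocates the KV cache at `max_length` along the sequence axis (axis 4), and each decode step copies its `new_kv_cache` into the next free slots and advances `pos`. A standalone sketch of just that buffer arithmetic, with illustrative shapes rather than the real moondream2 cache layout:

```rust
use ndarray::{s, Array, IxDyn};

fn main() {
    // Illustrative 6-axis cache: [layers, k/v, batch, heads, seq, dim].
    let (layers, heads, dim, max_len) = (2, 4, 8, 16);
    let mut kv = Array::<f32, IxDyn>::zeros(IxDyn(&[layers, 2, 1, heads, max_len, dim]));
    let mut pos = 3; // slots already filled by the image + prompt pass

    for _ in 0..2 {
        let inc = 1; // one token per step once generation starts
        let new = Array::<f32, IxDyn>::ones(IxDyn(&[layers, 2, 1, heads, inc, dim]));
        // Copy this step's fresh K/V into the next free slots, then advance the cursor.
        kv.slice_mut(s![.., .., .., .., pos..pos + inc, ..]).assign(&new);
        pos += inc;
    }
    assert_eq!(pos, 5);
}
```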
diff --git a/src/models/moondream2/mod.rs b/src/models/moondream2/mod.rs
new file mode 100644
index 0000000..53f1e2c
--- /dev/null
+++ b/src/models/moondream2/mod.rs
@@ -0,0 +1,4 @@
+mod config;
+mod r#impl;
+
+pub use r#impl::Moondream2;
diff --git a/src/models/yolo/impl.rs b/src/models/yolo/impl.rs
index 396b602..45dfa62 100644
--- a/src/models/yolo/impl.rs
+++ b/src/models/yolo/impl.rs
@@ -59,8 +59,8 @@ impl YOLO {
.to_processor()?
.with_image_width(width as _)
.with_image_height(height as _);
- let task: Option<Task> = match options.model_task {
- Some(task) => Some(task),
+ let task: Option<Task> = match &options.model_task {
+ Some(task) => Some(task.clone()),
None => match engine.try_fetch("task") {
Some(x) => match x.as_str() {
"classify" => Some(Task::ImageClassification),
@@ -104,7 +104,7 @@ impl YOLO {
// version + task
None => match (task, version) {
(Some(task), Some(version)) => {
- let layout = match (task, version) {
+ let layout = match (task.clone(), version) {
(Task::ImageClassification, Version(5, 0)) => {
YOLOPredsFormat::n_clss().apply_softmax(true)
}