From e2347353ba6fd13b9d8d85c4bd270b8a72418b2f Mon Sep 17 00:00:00 2001
From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com>
Date: Sat, 8 Feb 2025 00:28:35 +0800
Subject: [PATCH] Add SmolVLM model (#65)
* Add SmolVLM model
---
Cargo.toml | 4 +-
README.md | 1 +
examples/smolvlm/README.md | 8 +
examples/smolvlm/main.rs | 74 ++++++++
src/models/mod.rs | 2 +
src/models/smolvlm/README.md | 11 ++
src/models/smolvlm/config.rs | 58 +++++++
src/models/smolvlm/impl.rs | 323 +++++++++++++++++++++++++++++++++++
src/models/smolvlm/mod.rs | 4 +
src/xy/text.rs | 6 +
10 files changed, 489 insertions(+), 2 deletions(-)
create mode 100644 examples/smolvlm/README.md
create mode 100644 examples/smolvlm/main.rs
create mode 100644 src/models/smolvlm/README.md
create mode 100644 src/models/smolvlm/config.rs
create mode 100644 src/models/smolvlm/impl.rs
create mode 100644 src/models/smolvlm/mod.rs
diff --git a/Cargo.toml b/Cargo.toml
index efed00d..4eb29dc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,7 @@ anyhow = { version = "1.0.75" }
regex = { version = "1.5.4" }
rand = { version = "0.8.5" }
chrono = { version = "0.4.30" }
-tokenizers = { version = "0.15.2" }
+tokenizers = { version = "0.21.0" }
log = { version = "0.4.22" }
indicatif = "0.17.8"
serde_json = "1.0"
@@ -62,6 +62,6 @@ trt = [ "ort/tensorrt" ]
mps = [ "ort/coreml" ]
[profile.release]
-# lto = true
+lto = true
strip = true
panic = "abort"
diff --git a/README.md b/README.md
index d78eb5c..2c21a6a 100644
--- a/README.md
+++ b/README.md
@@ -88,6 +88,7 @@
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | ✅ | | |
| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | |
+| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | ✅ | ✅ | ✅ | | |
diff --git a/examples/smolvlm/README.md b/examples/smolvlm/README.md
new file mode 100644
index 0000000..e0c8405
--- /dev/null
+++ b/examples/smolvlm/README.md
@@ -0,0 +1,8 @@
+## Quick Start
+
+```shell
+cargo run -r --example smolvlm -- --scale 500m --source "images/green-car.jpg" --prompt "What's in it?"
+cargo run -r --example smolvlm -- --scale 500m --source "images/green-car.jpg" --prompt "What color is the car?"
+cargo run -r --example smolvlm -- --scale 500m --source "images/slanted-text-number.jpg" --prompt "What are these numbers?"
+cargo run -r --example smolvlm -- --scale 256m --source "images/Statue-of-Liberty-Island-New-York-Bay.jpg" --prompt "Can you describe this image?"
+```
\ No newline at end of file
diff --git a/examples/smolvlm/main.rs b/examples/smolvlm/main.rs
new file mode 100644
index 0000000..a6bf1e5
--- /dev/null
+++ b/examples/smolvlm/main.rs
@@ -0,0 +1,74 @@
+use anyhow::Result;
+use usls::{models::SmolVLM, DataLoader, Options, Scale};
+
+#[derive(argh::FromArgs)]
+/// Example
+struct Args {
+ /// device
+ #[argh(option, default = "String::from(\"cpu:0\")")]
+ device: String,
+
+ /// source image
+ #[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")]
+ source: Vec<String>,
+
+ /// prompt
+ #[argh(option, default = "String::from(\"Can you describe this image?\")")]
+ prompt: String,
+
+ /// scale
+ #[argh(option, default = "String::from(\"256m\")")]
+ scale: String,
+}
+
+fn main() -> Result<()> {
+ tracing_subscriber::fmt()
+ .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+ .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
+ .init();
+ let args: Args = argh::from_env();
+
+ // build model
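+ // SmolVLM ships as three ONNX graphs: a vision encoder, a token-embedding model,
+ // and a merged decoder that carries its own KV cache (see src/models/smolvlm/config.rs).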
+ let (options_vision_encoder, options_text_embed, options_decode) =
+ match args.scale.as_str().try_into()? {
+ Scale::Million(256.) => (
+ Options::smolvlm_vision_256m(),
+ Options::smolvlm_text_embed_256m(),
+ Options::smolvlm_decoder_256m(),
+ ),
+ Scale::Million(500.) => (
+ Options::smolvlm_vision_500m(),
+ Options::smolvlm_text_embed_500m(),
+ Options::smolvlm_decoder_500m(),
+ ),
+ _ => unimplemented!(),
+ };
+
+ let mut model = SmolVLM::new(
+ options_vision_encoder
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ options_text_embed
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ options_decode
+ .with_model_device(args.device.as_str().try_into()?)
+ .commit()?,
+ )?;
+
+ // load images
+ let xs = DataLoader::try_read_batch(&args.source)?;
+
+ // run
+ let ys = model.forward(&xs, &args.prompt)?;
+
+ for y in ys.iter() {
+ if let Some(texts) = y.texts() {
+ for text in texts {
+ println!("[User]: {}\n\n[Assistant]:{}", args.prompt, text);
+ }
+ }
+ }
+
+ Ok(())
+}
diff --git a/src/models/mod.rs b/src/models/mod.rs
index b80d974..9f260e1 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -25,6 +25,7 @@ mod rtmo;
mod sam;
mod sapiens;
mod slanet;
+mod smolvlm;
mod svtr;
mod trocr;
mod yolo;
@@ -48,6 +49,7 @@ pub use rtmo::*;
pub use sam::*;
pub use sapiens::*;
pub use slanet::*;
+pub use smolvlm::*;
pub use svtr::*;
pub use trocr::*;
pub use yolo::*;
diff --git a/src/models/smolvlm/README.md b/src/models/smolvlm/README.md
new file mode 100644
index 0000000..d134084
--- /dev/null
+++ b/src/models/smolvlm/README.md
@@ -0,0 +1,11 @@
+# SmolVLM - small yet mighty Vision Language Model
+
+## Official Repositories
+
+The official repositories can be found at:
+* [SmolVLM-256M-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct)
+* [SmolVLM-500M-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-500M-Instruct)
+
+## Example
+
+Refer to the [example](../../../examples/smolvlm)
diff --git a/src/models/smolvlm/config.rs b/src/models/smolvlm/config.rs
new file mode 100644
index 0000000..1aa4c9c
--- /dev/null
+++ b/src/models/smolvlm/config.rs
@@ -0,0 +1,58 @@
+/// Model configuration for `SmolVLM`
+impl crate::Options {
+ pub fn smolvlm() -> Self {
+ Self::default()
+ .with_batch_size(1)
+ .with_model_name("smolvlm")
+ .with_model_num_dry_run(3)
+ }
+
+ pub fn smolvlm_vision() -> Self {
+ Self::smolvlm()
+ .with_model_kind(crate::Kind::Vision)
+ .with_image_mean(&[0.5, 0.5, 0.5])
+ .with_image_std(&[0.5, 0.5, 0.5])
+ .with_resize_filter("lanczos3")
+ .with_normalize(true)
+ }
+
+ pub fn smolvlm_text() -> Self {
+ Self::smolvlm().with_model_kind(crate::Kind::Language)
+ }
+
+ pub fn smolvlm_vision_256m() -> Self {
+ Self::smolvlm_vision()
+ .with_model_scale(crate::Scale::Million(256.))
+ .with_model_file("256m-vision-encoder.onnx")
+ }
+
+ pub fn smolvlm_text_embed_256m() -> Self {
+ Self::smolvlm_text()
+ .with_model_scale(crate::Scale::Million(256.))
+ .with_model_file("256m-embed-tokens.onnx")
+ }
+
+ pub fn smolvlm_decoder_256m() -> Self {
+ Self::smolvlm_text()
+ .with_model_scale(crate::Scale::Million(256.))
+ .with_model_file("256m-decoder-model-merged.onnx")
+ }
+
+ pub fn smolvlm_vision_500m() -> Self {
+ Self::smolvlm_vision()
+ .with_model_scale(crate::Scale::Million(500.))
+ .with_model_file("500m-vision-encoder.onnx")
+ }
+
+ pub fn smolvlm_text_embed_500m() -> Self {
+ Self::smolvlm_text()
+ .with_model_scale(crate::Scale::Million(500.))
+ .with_model_file("500m-embed-tokens.onnx")
+ }
+
+ pub fn smolvlm_decoder_500m() -> Self {
+ Self::smolvlm_text()
+ .with_model_scale(crate::Scale::Million(500.))
+ .with_model_file("500m-decoder-model-merged.onnx")
+ }
+}
diff --git a/src/models/smolvlm/impl.rs b/src/models/smolvlm/impl.rs
new file mode 100644
index 0000000..8d3ed85
--- /dev/null
+++ b/src/models/smolvlm/impl.rs
@@ -0,0 +1,323 @@
+use aksr::Builder;
+use anyhow::Result;
+use image::{DynamicImage, GenericImageView};
+use ndarray::s;
+
+use crate::{
+ models::BaseModelTextual, Engine, LogitsSampler, Options, Processor, Scale, Ts, Xs, Ys, X, Y,
+};
+
+#[derive(Debug, Builder)]
+pub struct SmolVLM {
+ vision: VisionEncoder,
+ text_embed: BaseModelTextual,
+ decoder: BaseModelTextual,
+ scale: Scale,
+ image_token: String,
+ global_img_token: String,
+ fake_image_token: String,
+ bos_token: String,
+ eos_token: String,
+ eos_token_id: u32,
+ image_token_id: u32,
+ max_length: usize,
+ image_seq_len: usize,
+ num_hidden_layers: usize,
+ head_dim: usize,
+ num_key_value_heads: usize,
+}
+
+impl SmolVLM {
+ pub fn new(
+ options_vision_encoder: Options,
+ options_text_embed: Options,
+ options_decode: Options,
+ ) -> Result<Self> {
+ let vision = VisionEncoder::new(options_vision_encoder)?;
+ let text_embed = BaseModelTextual::new(options_text_embed)?;
+ let decoder = BaseModelTextual::new(options_decode)?;
+ let fake_image_token = "<fake_token_around_image>".to_string();
+ let image_token = "<image>".to_string();
+ let global_img_token = "<global-img>".to_string();
+ let bos_token = "<|im_start|>".to_string();
+ let eos_token = "<end_of_utterance>".to_string();
+ let eos_token_id = 2;
+ let image_token_id = 49190;
+ let image_seq_len = 64;
+ let max_length = 1024;
+ let (num_hidden_layers, head_dim, num_key_value_heads) = match decoder.scale() {
+ Some(Scale::Million(256.)) => (30, 64, 3),
+ Some(Scale::Million(500.)) => (32, 64, 5),
+ _ => unimplemented!(),
+ };
+ let scale = *decoder.scale().unwrap();
+
+ Ok(Self {
+ vision,
+ text_embed,
+ decoder,
+ scale,
+ max_length,
+ eos_token_id,
+ image_token,
+ image_token_id,
+ global_img_token,
+ fake_image_token,
+ num_hidden_layers,
+ head_dim,
+ num_key_value_heads,
+ bos_token,
+ eos_token,
+ image_seq_len,
+ })
+ }
+
+ pub fn forward(&mut self, images: &[DynamicImage], text: &str) -> Result<Ys> {
+ let mut ys: Vec<Y> = Vec::new();
+ for image in images.iter() {
+ let y = self.generate_one(image, text)?;
+ ys.push(Y::default().with_texts(&[y.into()]));
+ }
+
+ Ok(ys.into())
+ }
+
+ fn generate_one(&mut self, image: &DynamicImage, text: &str) -> Result<String> {
+ let bs = 1; // TODO
+
+ // patches and pixel_attention_mask
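+ // Tiles are resized to the encoder's input size rather than padded, so the
+ // pixel attention mask is simply all ones over [batch, num_patches, h, w].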
+ let (patches, nw_nh) = self.vision.process_one(image)?;
+ let dims = patches.dims();
+ let pixel_attention_mask = X::ones(&[dims[0], dims[1], dims[3], dims[4]]);
+
+ // input ids
+ let prompt = self.image_prompt_string(nw_nh, text);
+ let mut input_ids: Vec<f32> = self.text_embed.processor().encode_text_ids(&prompt, true)?;
+
+ // position ids
+ let mut position_ids = X::from(
+ (1..input_ids.len() + 1)
+ .map(|x| x as f32)
+ .collect::<Vec<f32>>(),
+ )
+ .insert_axis(0)?;
+
+ // past key_values
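+ // One (key, value) cache per decoder layer, interleaved; caches start empty
+ // (seq_len = 0) and grow by one token per decoding step.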
+ let mut past_key_values = vec![
+ X::zeros(&[bs, self.num_key_value_heads, 0, self.head_dim]);
+ self.num_hidden_layers * 2
+ ];
+
+ // generate
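+ // Autoregressive loop: embed the current ids, run the decoder once per token,
+ // sample the next id, and stop on EOS or at max_length.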
+ let logits_sampler = LogitsSampler::new();
+ let mut token_ids: Vec<u32> = vec![];
+ for ii in 0..self.max_length {
+ // inputs embeds
+ let input_ids_x = X::from(input_ids.clone()).insert_axis(0)?;
+ let mut inputs_embeds =
+ self.text_embed.inference(input_ids_x.clone().into())?[0].clone();
+
+ // encode image and merge
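+ // First step only: run the vision encoder once and scatter its per-patch
+ // features into the text embedding at every <image> placeholder position.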
+ if ii == 0 {
+ let image_features = self.vision.inference(Xs::from(vec![
+ patches.clone(),
+ pixel_attention_mask.clone(),
+ ]))?[0]
+ .clone();
+ let dims = image_features.dims();
+ let image_features = image_features.to_shape((dims[0] * dims[1], dims[2]))?;
+
+ // merge
+ let mut r = 0;
+ for (i, &token_id) in input_ids_x.indexed_iter() {
+ if token_id == self.image_token_id as f32 {
+ inputs_embeds
+ .0
+ .slice_mut(s![0, i[1], ..])
+ .assign(&image_features.slice(s![r, ..]));
+ r += 1;
+ }
+ }
+ }
+
+ // inputs
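+ // Input order must match the merged decoder's ONNX signature:
+ // inputs_embeds, attention_mask, position_ids, then past (key, value) per layer.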
+ let mut xs = vec![
+ inputs_embeds.clone(),
+ X::ones_like(&input_ids_x),
+ position_ids.clone(),
+ ];
+ for i in 0..self.num_hidden_layers {
+ xs.push(past_key_values[i * 2].clone());
+ xs.push(past_key_values[i * 2 + 1].clone());
+ }
+
+ // decode
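+ // Output layout: logits first, then present key/value tensors per layer,
+ // which become the next step's cache.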
+ let decoder_outputs = self.decoder.inference(xs.into())?;
+ let logits = &decoder_outputs[0];
+ past_key_values = (1..decoder_outputs.len())
+ .step_by(2)
+ .flat_map(|i| [i, i + 1])
+ .map(|i| decoder_outputs[i].clone())
+ .collect();
+ let token_id = logits_sampler.decode(
+ &logits
+ .slice(s![0, -1, ..])
+ .into_owned()
+ .into_raw_vec_and_offset()
+ .0,
+ )?;
+
+ // early return
+ if token_id == self.eos_token_id {
+ break;
+ } else {
+ token_ids.push(token_id);
+ }
+
+ // update
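+ // Feed only the newly sampled token next step; its position id advances by
+ // one since the earlier context lives in the KV cache.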
+ input_ids = vec![token_id as f32];
+ position_ids = X::from(
+ position_ids
+ .slice(s![.., -1..])
+ .mapv(|x| x + 1.0)
+ .into_owned()
+ .into_dyn(),
+ );
+ }
+
+ // decode tokens
+ let text = self
+ .text_embed
+ .processor()
+ .decode_tokens(&token_ids, true)?;
+
+ Ok(text)
+ }
+
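+ /// Builds the chat prompt. A 1x1 grid uses only the global image block;
+ /// tiled images introduce each sub-image with a <row_r_col_c> marker followed
+ /// by a run of image placeholder tokens, with the global view appended last.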
+ fn image_prompt_string(&self, nw_nh: (u32, u32), text: &str) -> String {
+ let (nw, nh) = nw_nh;
+ let image_tokens = self.image_token.repeat(self.image_seq_len);
+ let s1 = format!("{}User:", self.bos_token);
+ let s_global = format!(
+ "{}{}{}{}{}{}\nAssistant:",
+ self.fake_image_token,
+ self.global_img_token,
+ image_tokens,
+ self.fake_image_token,
+ text,
+ self.eos_token
+ );
+
+ match nw_nh {
+ (1, 1) => format!("{}{}", s1, s_global),
+ _ => {
+ let mut s = String::with_capacity(
+ s1.len()
+ + (nw as usize
+ * nh as usize
+ * (self.fake_image_token.len() + image_tokens.len() + 20))
+ + 10,
+ );
+ s.push_str(&s1);
+ for r in 1..=nh {
+ for c in 1..=nw {
+ s.push_str(&format!(
+ "{}{}",
+ self.fake_image_token, r, c, image_tokens
+ ));
+ }
+ s.push('\n');
+ }
+ format!("{}\n{}", s, s_global)
+ }
+ }
+ }
+}
+
+#[derive(Debug, Builder)]
+pub struct VisionEncoder {
+ engine: Engine,
+ num_patch: usize,
+ batch: usize,
+ width: usize,
+ height: usize,
+ processor: Processor,
+ ts: Ts,
+}
+
+impl VisionEncoder {
+ pub fn new(options: Options) -> Result<Self> {
+ let engine = options.to_engine()?;
+ let (batch, num_patch, height, width, ts) = (
+ engine.batch().opt(),
+ engine.inputs_minoptmax()[0][1].opt(),
+ engine.inputs_minoptmax()[0][3].opt(),
+ engine.inputs_minoptmax()[0][4].opt(),
+ engine.ts.clone(),
+ );
+ let processor = options
+ .to_processor()?
+ .with_image_width(width as _)
+ .with_image_height(height as _);
+
+ Ok(Self {
+ engine,
+ num_patch,
+ batch,
+ width,
+ height,
+ processor,
+ ts,
+ })
+ }
+
+ fn create_patches(
+ image: &DynamicImage,
+ patch_size: (u32, u32),
+ ) -> (Vec<DynamicImage>, (u32, u32)) {
+ let mut patches = vec![];
+ let image_rgb8 = image.to_rgb8();
+ let (image_width, image_height) = image_rgb8.dimensions();
+ let (patch_width, patch_height) = patch_size;
+
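+ // Tile images larger than the encoder's native size into an nw x nh grid of
+ // roughly equal sub-images; the original is appended last as the global view.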
+ let (nw, nh) = if image_width > patch_width || image_height > patch_height {
+ // calculate the number of splits
+ let nw = image_width.div_ceil(patch_width);
+ let nh = image_height.div_ceil(patch_height);
+
+ // calculate the optimal width and height for the sub-images
+ let optimal_height = image_height.div_ceil(nh);
+ let optimal_width = image_width.div_ceil(nw);
+
+ // SubImage
+ for r in 0..nh {
+ for c in 0..nw {
+ let x0 = c * optimal_width;
+ let y0 = r * optimal_height;
+ let x1 = (x0 + optimal_width).min(image_width);
+ let y1 = (y0 + optimal_height).min(image_height);
+ let sub_image = image_rgb8.view(x0, y0, x1 - x0, y1 - y0).to_image();
+ patches.push(DynamicImage::from(sub_image));
+ }
+ }
+ (nw, nh)
+ } else {
+ (1, 1)
+ };
+ patches.push(image.clone());
+
+ (patches, (nw, nh))
+ }
+
+ pub fn inference(&mut self, xs: Xs) -> Result {
+ self.engine.run(xs)
+ }
+
+ pub fn process_one(&mut self, x: &DynamicImage) -> Result<(X, (u32, u32))> {
+ let (patches, nw_nh) = Self::create_patches(x, (self.width as _, self.height as _));
+ let patches = self.processor.process_images(&patches)?.insert_axis(0)?;
+
+ Ok((patches, nw_nh))
+ }
+}
diff --git a/src/models/smolvlm/mod.rs b/src/models/smolvlm/mod.rs
new file mode 100644
index 0000000..cc4987a
--- /dev/null
+++ b/src/models/smolvlm/mod.rs
@@ -0,0 +1,4 @@
+mod config;
+mod r#impl;
+
+pub use r#impl::SmolVLM;
diff --git a/src/xy/text.rs b/src/xy/text.rs
index 0c67b5f..7fc84e7 100644
--- a/src/xy/text.rs
+++ b/src/xy/text.rs
@@ -2,6 +2,12 @@
#[derive(aksr::Builder, Debug, Clone, Default, PartialEq)]
pub struct Text(String);
+impl std::fmt::Display for Text {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
impl std::ops::Deref for Text {
type Target = String;