Add SmolVLM model (#65)

* Add SmolVLM model
This commit is contained in:
Jamjamjon
2025-02-08 00:28:35 +08:00
committed by GitHub
parent bdd77a6d21
commit e2347353ba
10 changed files with 489 additions and 2 deletions

View File

@ -20,7 +20,7 @@ anyhow = { version = "1.0.75" }
regex = { version = "1.5.4" }
rand = { version = "0.8.5" }
chrono = { version = "0.4.30" }
tokenizers = { version = "0.15.2" }
tokenizers = { version = "0.21.0" }
log = { version = "0.4.22" }
indicatif = "0.17.8"
serde_json = "1.0"
@ -62,6 +62,6 @@ trt = [ "ort/tensorrt" ]
mps = [ "ort/coreml" ]
[profile.release]
# lto = true
lto = true
strip = true
panic = "abort"

View File

@ -88,6 +88,7 @@
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | ✅ | | |
| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | |
| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | ✅ | ✅ | ✅ | | |
</details>

View File

@ -0,0 +1,8 @@
## Quick Start
```shell
cargo run -r --example smolvlm -- --scale 500m --source "images/green-car.jpg" --prompt "What's in it?"
cargo run -r --example smolvlm -- --scale 500m --source "images/green-car.jpg" --prompt "What color is the car?"
cargo run -r --example smolvlm -- --scale 500m --source "images/slanted-text-number.jpg" --prompt "What are these numbers?"
cargo run -r --example smolvlm -- --scale 256m --source "images/Statue-of-Liberty-Island-New-York-Bay.jpg" --prompt "Can you describe this image?"
```

74
examples/smolvlm/main.rs Normal file
View File

@ -0,0 +1,74 @@
use anyhow::Result;
use usls::{models::SmolVLM, DataLoader, Options, Scale};
#[derive(argh::FromArgs)]
/// Example
struct Args {
/// device
#[argh(option, default = "String::from(\"cpu:0\")")]
device: String,
/// source image
#[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")]
source: Vec<String>,
/// promt
#[argh(option, default = "String::from(\"Can you describe this image?\")")]
prompt: String,
/// scale
#[argh(option, default = "String::from(\"256m\")")]
scale: String,
}
fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
.init();
let args: Args = argh::from_env();
// build model
let (options_vision_encoder, options_text_embed, options_decode) =
match args.scale.as_str().try_into()? {
Scale::Million(256.) => (
Options::smolvlm_vision_256m(),
Options::smolvlm_text_embed_256m(),
Options::smolvlm_decoder_256m(),
),
Scale::Million(500.) => (
Options::smolvlm_vision_500m(),
Options::smolvlm_text_embed_500m(),
Options::smolvlm_decoder_500m(),
),
_ => unimplemented!(),
};
let mut model = SmolVLM::new(
options_vision_encoder
.with_model_device(args.device.as_str().try_into()?)
.commit()?,
options_text_embed
.with_model_device(args.device.as_str().try_into()?)
.commit()?,
options_decode
.with_model_device(args.device.as_str().try_into()?)
.commit()?,
)?;
// load images
let xs = DataLoader::try_read_batch(&args.source)?;
// run
let ys = model.forward(&xs, &args.prompt)?;
for y in ys.iter() {
if let Some(texts) = y.texts() {
for text in texts {
println!("[User]: {}\n\n[Assistant]:{}", args.prompt, text);
}
}
}
Ok(())
}

View File

@ -25,6 +25,7 @@ mod rtmo;
mod sam;
mod sapiens;
mod slanet;
mod smolvlm;
mod svtr;
mod trocr;
mod yolo;
@ -48,6 +49,7 @@ pub use rtmo::*;
pub use sam::*;
pub use sapiens::*;
pub use slanet::*;
pub use smolvlm::*;
pub use svtr::*;
pub use trocr::*;
pub use yolo::*;

View File

@ -0,0 +1,11 @@
# SmolVLM - small yet mighty Vision Language Model
## Official Repository
The official repository can be found on:
* [SmolVLM-256M-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct)
* [SmolVLM-500M-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-500M-Instruct)
## Example
Refer to the [example](../../../examples/smolvlm)

View File

@ -0,0 +1,58 @@
/// Model configuration for `SmolVLM`
impl crate::Options {
pub fn smolvlm() -> Self {
Self::default()
.with_batch_size(1)
.with_model_name("smolvlm")
.with_model_num_dry_run(3)
}
pub fn smolvlm_vision() -> Self {
Self::smolvlm()
.with_model_kind(crate::Kind::Vision)
.with_image_mean(&[0.5, 0.5, 0.5])
.with_image_std(&[0.5, 0.5, 0.5])
.with_resize_filter("lanczos3")
.with_normalize(true)
}
pub fn smolvlm_text() -> Self {
Self::smolvlm().with_model_kind(crate::Kind::Language)
}
pub fn smolvlm_vision_256m() -> Self {
Self::smolvlm_vision()
.with_model_scale(crate::Scale::Million(256.))
.with_model_file("256m-vision-encoder.onnx")
}
pub fn smolvlm_text_embed_256m() -> Self {
Self::smolvlm_text()
.with_model_scale(crate::Scale::Million(256.))
.with_model_file("256m-embed-tokens.onnx")
}
pub fn smolvlm_decoder_256m() -> Self {
Self::smolvlm_text()
.with_model_scale(crate::Scale::Million(256.))
.with_model_file("256m-decoder-model-merged.onnx")
}
pub fn smolvlm_vision_500m() -> Self {
Self::smolvlm_vision()
.with_model_scale(crate::Scale::Million(500.))
.with_model_file("500m-vision-encoder.onnx")
}
pub fn smolvlm_text_embed_500m() -> Self {
Self::smolvlm_text()
.with_model_scale(crate::Scale::Million(500.))
.with_model_file("500m-embed-tokens.onnx")
}
pub fn smolvlm_decoder_500m() -> Self {
Self::smolvlm_text()
.with_model_scale(crate::Scale::Million(500.))
.with_model_file("500m-decoder-model-merged.onnx")
}
}

323
src/models/smolvlm/impl.rs Normal file
View File

@ -0,0 +1,323 @@
use aksr::Builder;
use anyhow::Result;
use image::{DynamicImage, GenericImageView};
use ndarray::s;
use crate::{
models::BaseModelTextual, Engine, LogitsSampler, Options, Processor, Scale, Ts, Xs, Ys, X, Y,
};
#[derive(Debug, Builder)]
pub struct SmolVLM {
vision: VisionEncoder,
text_embed: BaseModelTextual,
decoder: BaseModelTextual,
scale: Scale,
image_token: String,
global_img_token: String,
fake_image_token: String,
bos_token: String,
eos_token: String,
eos_token_id: u32,
image_token_id: u32,
max_length: usize,
image_seq_len: usize,
num_hidden_layers: usize,
head_dim: usize,
num_key_value_heads: usize,
}
impl SmolVLM {
pub fn new(
options_vision_encoder: Options,
options_text_embed: Options,
options_decode: Options,
) -> Result<Self> {
let vision = VisionEncoder::new(options_vision_encoder)?;
let text_embed = BaseModelTextual::new(options_text_embed)?;
let decoder = BaseModelTextual::new(options_decode)?;
let fake_image_token = "<fake_token_around_image>".to_string();
let image_token = "<image>".to_string();
let global_img_token = "<global-img>".to_string();
let bos_token = "<|im_start|>".to_string();
let eos_token = "<end_of_utterance>".to_string();
let eos_token_id = 2;
let image_token_id = 49190;
let image_seq_len = 64;
let max_length = 1024;
let (num_hidden_layers, head_dim, num_key_value_heads) = match decoder.scale() {
Some(Scale::Million(256.)) => (30, 64, 3),
Some(Scale::Million(500.)) => (32, 64, 5),
_ => unimplemented!(),
};
let scale = *decoder.scale().unwrap();
Ok(Self {
vision,
text_embed,
decoder,
scale,
max_length,
eos_token_id,
image_token,
image_token_id,
global_img_token,
fake_image_token,
num_hidden_layers,
head_dim,
num_key_value_heads,
bos_token,
eos_token,
image_seq_len,
})
}
pub fn forward(&mut self, images: &[DynamicImage], text: &str) -> Result<Ys> {
let mut ys: Vec<Y> = Vec::new();
for image in images.iter() {
let y = self.generate_one(image, text)?;
ys.push(Y::default().with_texts(&[y.into()]));
}
Ok(ys.into())
}
fn generate_one(&mut self, image: &DynamicImage, text: &str) -> Result<String> {
let bs = 1; // TODO
// patches and pixel_attention_mask
let (patches, nw_nh) = self.vision.process_one(image)?;
let dims = patches.dims();
let pixel_attention_mask = X::ones(&[dims[0], dims[1], dims[3], dims[4]]);
// input ids
let prompt = self.image_prompt_string(nw_nh, text);
let mut input_ids: Vec<f32> = self.text_embed.processor().encode_text_ids(&prompt, true)?;
// position ids
let mut position_ids = X::from(
(1..input_ids.len() + 1)
.map(|x| x as f32)
.collect::<Vec<f32>>(),
)
.insert_axis(0)?;
// past key_values
let mut past_key_values = vec![
X::zeros(&[bs, self.num_key_value_heads, 0, self.head_dim]);
self.num_hidden_layers * 2
];
// generate
let logits_sampler = LogitsSampler::new();
let mut token_ids: Vec<u32> = vec![];
for ii in 0..self.max_length {
// inputs embeds
let input_ids_x = X::from(input_ids.clone()).insert_axis(0)?;
let mut inputs_embeds =
self.text_embed.inference(input_ids_x.clone().into())?[0].clone();
// encode image and merge
if ii == 0 {
let image_features = self.vision.inference(Xs::from(vec![
patches.clone(),
pixel_attention_mask.clone(),
]))?[0]
.clone();
let dims = image_features.dim();
let image_features = image_features.to_shape((dims[0] * dims[1], dims[2]))?;
// merge
let mut r = 0;
for (i, &token_id) in input_ids_x.indexed_iter() {
if token_id == self.image_token_id as f32 {
inputs_embeds
.0
.slice_mut(s![0, i[1], ..])
.assign(&image_features.slice(s![r, ..]));
r += 1;
}
}
}
// inputs
let mut xs = vec![
inputs_embeds.clone(),
X::ones_like(&input_ids_x),
position_ids.clone(),
];
for i in 0..self.num_hidden_layers {
xs.push(past_key_values[i * 2].clone());
xs.push(past_key_values[i * 2 + 1].clone());
}
// decode
let decoder_outputs = self.decoder.inference(xs.into())?;
let logits = &decoder_outputs[0];
past_key_values = (1..decoder_outputs.len())
.step_by(2)
.flat_map(|i| [i, i + 1])
.map(|i| decoder_outputs[i].clone())
.collect();
let token_id = logits_sampler.decode(
&logits
.slice(s![0, -1, ..])
.into_owned()
.into_raw_vec_and_offset()
.0,
)?;
// early return
if token_id == self.eos_token_id {
break;
} else {
token_ids.push(token_id);
}
// update
input_ids = vec![token_id as f32];
position_ids = X::from(
position_ids
.slice(s![.., -1..])
.mapv(|x| x + 1.0)
.into_owned()
.into_dyn(),
);
}
// decode tokens
let text = self
.text_embed
.processor()
.decode_tokens(&token_ids, true)?;
Ok(text)
}
fn image_prompt_string(&self, nw_nh: (u32, u32), text: &str) -> String {
let (nw, nh) = nw_nh;
let image_tokens = self.image_token.repeat(self.image_seq_len);
let s1 = format!("{}User:", self.bos_token);
let s_global = format!(
"{}{}{}{}{}{}\nAssistant:",
self.fake_image_token,
self.global_img_token,
image_tokens,
self.fake_image_token,
text,
self.eos_token
);
match nw_nh {
(1, 1) => format!("{}{}", s1, s_global),
_ => {
let mut s = String::with_capacity(
s1.len()
+ (nw as usize
* nh as usize
* (self.fake_image_token.len() + image_tokens.len() + 20))
+ 10,
);
s.push_str(&s1);
// let mut s = s1;
for r in 1..=nh {
for c in 1..=nw {
s.push_str(&format!(
"{}<row_{}_col_{}>{}",
self.fake_image_token, r, c, image_tokens
));
}
s.push('\n');
}
format!("{}\n{}", s, s_global)
}
}
}
}
#[derive(Debug, Builder)]
pub struct VisionEncoder {
engine: Engine,
num_patch: usize,
batch: usize,
width: usize,
height: usize,
processor: Processor,
ts: Ts,
}
impl VisionEncoder {
pub fn new(options: Options) -> Result<Self> {
let engine = options.to_engine()?;
let (batch, num_patch, height, width, ts) = (
engine.batch().opt(),
engine.inputs_minoptmax()[0][1].opt(),
engine.inputs_minoptmax()[0][3].opt(),
engine.inputs_minoptmax()[0][4].opt(),
engine.ts.clone(),
);
let processor = options
.to_processor()?
.with_image_width(width as _)
.with_image_height(height as _);
Ok(Self {
engine,
num_patch,
batch,
width,
height,
processor,
ts,
})
}
fn create_patches(
image: &DynamicImage,
patch_size: (u32, u32),
) -> (Vec<DynamicImage>, (u32, u32)) {
let mut patches = vec![];
let image_rgb8 = image.to_rgb8();
let (image_width, image_height) = image_rgb8.dimensions();
let (patch_width, patch_height) = patch_size;
let (nw, nh) = if image_width > patch_width || image_height > patch_height {
// calculate the number of splits
let nw = image_width.div_ceil(patch_width);
let nh = image_height.div_ceil(patch_height);
// calculate the optimal width and height for the sub-images
let optimal_height = image_height.div_ceil(nh);
let optimal_width = image_width.div_ceil(nw);
// SubImage
for r in 0..nh {
for c in 0..nw {
let x0 = c * optimal_width;
let y0 = r * optimal_height;
let x1 = (x0 + optimal_width).min(image_width);
let y1 = (y0 + optimal_height).min(image_height);
let sub_image = image_rgb8.view(x0, y0, x1 - x0, y1 - y0).to_image();
patches.push(DynamicImage::from(sub_image));
}
}
(nw, nh)
} else {
(1, 1)
};
patches.push(image.clone());
(patches, (nw, nh))
}
pub fn inference(&mut self, xs: Xs) -> Result<Xs> {
self.engine.run(xs)
}
pub fn process_one(&mut self, x: &DynamicImage) -> Result<(X, (u32, u32))> {
let (patches, nw_nh) = Self::create_patches(x, (self.width as _, self.height as _));
let patches = self.processor.process_images(&patches)?.insert_axis(0)?;
Ok((patches, (nw_nh)))
}
}

View File

@ -0,0 +1,4 @@
mod config;
mod r#impl;
pub use r#impl::SmolVLM;

View File

@ -2,6 +2,12 @@
#[derive(aksr::Builder, Debug, Clone, Default, PartialEq)]
pub struct Text(String);
impl std::fmt::Display for Text {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl std::ops::Deref for Text {
type Target = String;