mirror of
https://github.com/mii443/usls.git
synced 2025-08-22 15:45:41 +00:00
@ -20,7 +20,7 @@ anyhow = { version = "1.0.75" }
|
|||||||
regex = { version = "1.5.4" }
|
regex = { version = "1.5.4" }
|
||||||
rand = { version = "0.8.5" }
|
rand = { version = "0.8.5" }
|
||||||
chrono = { version = "0.4.30" }
|
chrono = { version = "0.4.30" }
|
||||||
tokenizers = { version = "0.15.2" }
|
tokenizers = { version = "0.21.0" }
|
||||||
log = { version = "0.4.22" }
|
log = { version = "0.4.22" }
|
||||||
indicatif = "0.17.8"
|
indicatif = "0.17.8"
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
@ -62,6 +62,6 @@ trt = [ "ort/tensorrt" ]
|
|||||||
mps = [ "ort/coreml" ]
|
mps = [ "ort/coreml" ]
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
# lto = true
|
lto = true
|
||||||
strip = true
|
strip = true
|
||||||
panic = "abort"
|
panic = "abort"
|
||||||
|
@ -88,6 +88,7 @@
|
|||||||
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | ✅ | | |
|
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | ✅ | | |
|
||||||
| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
|
| [Moondream2](https://github.com/vikhyat/moondream/tree/main) | Open-Set Object Detection<br />Open-Set Keypoints Detection<br />Image Caption<br />Visual Question Answering | [demo](examples/moondream2) | ✅ | ✅ | ✅ | | |
|
||||||
| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | |
|
| [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Set Object Detection | [demo](examples/owlv2) | ✅ | ✅ | ✅ | | |
|
||||||
|
| [SmolVLM(256M, 500M)](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | Visual Question Answering | [demo](examples/smolvlm) | ✅ | ✅ | ✅ | | |
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
8
examples/smolvlm/README.md
Normal file
8
examples/smolvlm/README.md
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cargo run -r --example smolvlm -- --scale 500m --source "images/green-car.jpg" --prompt "What's in it?"
|
||||||
|
cargo run -r --example smolvlm -- --scale 500m --source "images/green-car.jpg" --prompt "What color is the car?"
|
||||||
|
cargo run -r --example smolvlm -- --scale 500m --source "images/slanted-text-number.jpg" --prompt "What are these numbers?"
|
||||||
|
cargo run -r --example smolvlm -- --scale 256m --source "images/Statue-of-Liberty-Island-New-York-Bay.jpg" --prompt "Can you describe this image?"
|
||||||
|
```
|
74
examples/smolvlm/main.rs
Normal file
74
examples/smolvlm/main.rs
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use usls::{models::SmolVLM, DataLoader, Options, Scale};
|
||||||
|
|
||||||
|
#[derive(argh::FromArgs)]
|
||||||
|
/// Example
|
||||||
|
struct Args {
|
||||||
|
/// device
|
||||||
|
#[argh(option, default = "String::from(\"cpu:0\")")]
|
||||||
|
device: String,
|
||||||
|
|
||||||
|
/// source image
|
||||||
|
#[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")]
|
||||||
|
source: Vec<String>,
|
||||||
|
|
||||||
|
/// promt
|
||||||
|
#[argh(option, default = "String::from(\"Can you describe this image?\")")]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// scale
|
||||||
|
#[argh(option, default = "String::from(\"256m\")")]
|
||||||
|
scale: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||||
|
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339())
|
||||||
|
.init();
|
||||||
|
let args: Args = argh::from_env();
|
||||||
|
|
||||||
|
// build model
|
||||||
|
let (options_vision_encoder, options_text_embed, options_decode) =
|
||||||
|
match args.scale.as_str().try_into()? {
|
||||||
|
Scale::Million(256.) => (
|
||||||
|
Options::smolvlm_vision_256m(),
|
||||||
|
Options::smolvlm_text_embed_256m(),
|
||||||
|
Options::smolvlm_decoder_256m(),
|
||||||
|
),
|
||||||
|
Scale::Million(500.) => (
|
||||||
|
Options::smolvlm_vision_500m(),
|
||||||
|
Options::smolvlm_text_embed_500m(),
|
||||||
|
Options::smolvlm_decoder_500m(),
|
||||||
|
),
|
||||||
|
_ => unimplemented!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut model = SmolVLM::new(
|
||||||
|
options_vision_encoder
|
||||||
|
.with_model_device(args.device.as_str().try_into()?)
|
||||||
|
.commit()?,
|
||||||
|
options_text_embed
|
||||||
|
.with_model_device(args.device.as_str().try_into()?)
|
||||||
|
.commit()?,
|
||||||
|
options_decode
|
||||||
|
.with_model_device(args.device.as_str().try_into()?)
|
||||||
|
.commit()?,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// load images
|
||||||
|
let xs = DataLoader::try_read_batch(&args.source)?;
|
||||||
|
|
||||||
|
// run
|
||||||
|
let ys = model.forward(&xs, &args.prompt)?;
|
||||||
|
|
||||||
|
for y in ys.iter() {
|
||||||
|
if let Some(texts) = y.texts() {
|
||||||
|
for text in texts {
|
||||||
|
println!("[User]: {}\n\n[Assistant]:{}", args.prompt, text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -25,6 +25,7 @@ mod rtmo;
|
|||||||
mod sam;
|
mod sam;
|
||||||
mod sapiens;
|
mod sapiens;
|
||||||
mod slanet;
|
mod slanet;
|
||||||
|
mod smolvlm;
|
||||||
mod svtr;
|
mod svtr;
|
||||||
mod trocr;
|
mod trocr;
|
||||||
mod yolo;
|
mod yolo;
|
||||||
@ -48,6 +49,7 @@ pub use rtmo::*;
|
|||||||
pub use sam::*;
|
pub use sam::*;
|
||||||
pub use sapiens::*;
|
pub use sapiens::*;
|
||||||
pub use slanet::*;
|
pub use slanet::*;
|
||||||
|
pub use smolvlm::*;
|
||||||
pub use svtr::*;
|
pub use svtr::*;
|
||||||
pub use trocr::*;
|
pub use trocr::*;
|
||||||
pub use yolo::*;
|
pub use yolo::*;
|
||||||
|
11
src/models/smolvlm/README.md
Normal file
11
src/models/smolvlm/README.md
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# SmolVLM - small yet mighty Vision Language Model
|
||||||
|
|
||||||
|
## Official Repository
|
||||||
|
|
||||||
|
The official model repositories can be found at:
|
||||||
|
* [SmolVLM-256M-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct)
|
||||||
|
* [SmolVLM-500M-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-500M-Instruct)
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
Refer to the [example](../../../examples/smolvlm)
|
58
src/models/smolvlm/config.rs
Normal file
58
src/models/smolvlm/config.rs
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
/// Model configuration for `SmolVLM`
|
||||||
|
impl crate::Options {
|
||||||
|
pub fn smolvlm() -> Self {
|
||||||
|
Self::default()
|
||||||
|
.with_batch_size(1)
|
||||||
|
.with_model_name("smolvlm")
|
||||||
|
.with_model_num_dry_run(3)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn smolvlm_vision() -> Self {
|
||||||
|
Self::smolvlm()
|
||||||
|
.with_model_kind(crate::Kind::Vision)
|
||||||
|
.with_image_mean(&[0.5, 0.5, 0.5])
|
||||||
|
.with_image_std(&[0.5, 0.5, 0.5])
|
||||||
|
.with_resize_filter("lanczos3")
|
||||||
|
.with_normalize(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn smolvlm_text() -> Self {
|
||||||
|
Self::smolvlm().with_model_kind(crate::Kind::Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn smolvlm_vision_256m() -> Self {
|
||||||
|
Self::smolvlm_vision()
|
||||||
|
.with_model_scale(crate::Scale::Million(256.))
|
||||||
|
.with_model_file("256m-vision-encoder.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn smolvlm_text_embed_256m() -> Self {
|
||||||
|
Self::smolvlm_text()
|
||||||
|
.with_model_scale(crate::Scale::Million(256.))
|
||||||
|
.with_model_file("256m-embed-tokens.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn smolvlm_decoder_256m() -> Self {
|
||||||
|
Self::smolvlm_text()
|
||||||
|
.with_model_scale(crate::Scale::Million(256.))
|
||||||
|
.with_model_file("256m-decoder-model-merged.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn smolvlm_vision_500m() -> Self {
|
||||||
|
Self::smolvlm_vision()
|
||||||
|
.with_model_scale(crate::Scale::Million(500.))
|
||||||
|
.with_model_file("500m-vision-encoder.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn smolvlm_text_embed_500m() -> Self {
|
||||||
|
Self::smolvlm_text()
|
||||||
|
.with_model_scale(crate::Scale::Million(500.))
|
||||||
|
.with_model_file("500m-embed-tokens.onnx")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn smolvlm_decoder_500m() -> Self {
|
||||||
|
Self::smolvlm_text()
|
||||||
|
.with_model_scale(crate::Scale::Million(500.))
|
||||||
|
.with_model_file("500m-decoder-model-merged.onnx")
|
||||||
|
}
|
||||||
|
}
|
323
src/models/smolvlm/impl.rs
Normal file
323
src/models/smolvlm/impl.rs
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
use aksr::Builder;
|
||||||
|
use anyhow::Result;
|
||||||
|
use image::{DynamicImage, GenericImageView};
|
||||||
|
use ndarray::s;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
models::BaseModelTextual, Engine, LogitsSampler, Options, Processor, Scale, Ts, Xs, Ys, X, Y,
|
||||||
|
};
|
||||||
|
|
||||||
|
// SmolVLM vision-language model: a vision encoder, a token-embedding model
// and a decoder, plus the special tokens and decoder geometry needed to
// build prompts and drive generation.
#[derive(Debug, Builder)]
pub struct SmolVLM {
    // Vision encoder sub-model (also owns the image pre-processing).
    vision: VisionEncoder,
    // Token-embedding sub-model; its processor is also used for
    // tokenization and detokenization.
    text_embed: BaseModelTextual,
    // Decoder sub-model that produces logits and key/value outputs.
    decoder: BaseModelTextual,
    // Model scale, copied from the decoder in `new`.
    scale: Scale,
    // Image placeholder token; `new` sets it to "<image>".
    image_token: String,
    // Marker for the whole-image view; `new` sets it to "<global-img>".
    global_img_token: String,
    // Delimiter placed around image token runs; `new` sets it to
    // "<fake_token_around_image>".
    fake_image_token: String,
    // Token that opens the prompt; `new` sets it to "<|im_start|>".
    bos_token: String,
    // Token that closes the user turn; `new` sets it to "<end_of_utterance>".
    eos_token: String,
    // Token id that stops generation; `new` sets it to 2.
    eos_token_id: u32,
    // Token id whose embeddings are replaced by image features; `new` sets
    // it to 49190.
    image_token_id: u32,
    // Upper bound on decoding iterations; `new` sets it to 1024.
    max_length: usize,
    // Number of image tokens emitted per image patch; `new` sets it to 64.
    image_seq_len: usize,
    // Decoder geometry selected from the scale in `new`; used to shape the
    // initial (empty) past key/value inputs.
    num_hidden_layers: usize,
    head_dim: usize,
    num_key_value_heads: usize,
}
|
||||||
|
|
||||||
|
impl SmolVLM {
|
||||||
|
pub fn new(
|
||||||
|
options_vision_encoder: Options,
|
||||||
|
options_text_embed: Options,
|
||||||
|
options_decode: Options,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let vision = VisionEncoder::new(options_vision_encoder)?;
|
||||||
|
let text_embed = BaseModelTextual::new(options_text_embed)?;
|
||||||
|
let decoder = BaseModelTextual::new(options_decode)?;
|
||||||
|
let fake_image_token = "<fake_token_around_image>".to_string();
|
||||||
|
let image_token = "<image>".to_string();
|
||||||
|
let global_img_token = "<global-img>".to_string();
|
||||||
|
let bos_token = "<|im_start|>".to_string();
|
||||||
|
let eos_token = "<end_of_utterance>".to_string();
|
||||||
|
let eos_token_id = 2;
|
||||||
|
let image_token_id = 49190;
|
||||||
|
let image_seq_len = 64;
|
||||||
|
let max_length = 1024;
|
||||||
|
let (num_hidden_layers, head_dim, num_key_value_heads) = match decoder.scale() {
|
||||||
|
Some(Scale::Million(256.)) => (30, 64, 3),
|
||||||
|
Some(Scale::Million(500.)) => (32, 64, 5),
|
||||||
|
_ => unimplemented!(),
|
||||||
|
};
|
||||||
|
let scale = *decoder.scale().unwrap();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
vision,
|
||||||
|
text_embed,
|
||||||
|
decoder,
|
||||||
|
scale,
|
||||||
|
max_length,
|
||||||
|
eos_token_id,
|
||||||
|
image_token,
|
||||||
|
image_token_id,
|
||||||
|
global_img_token,
|
||||||
|
fake_image_token,
|
||||||
|
num_hidden_layers,
|
||||||
|
head_dim,
|
||||||
|
num_key_value_heads,
|
||||||
|
bos_token,
|
||||||
|
eos_token,
|
||||||
|
image_seq_len,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, images: &[DynamicImage], text: &str) -> Result<Ys> {
|
||||||
|
let mut ys: Vec<Y> = Vec::new();
|
||||||
|
for image in images.iter() {
|
||||||
|
let y = self.generate_one(image, text)?;
|
||||||
|
ys.push(Y::default().with_texts(&[y.into()]));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ys.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_one(&mut self, image: &DynamicImage, text: &str) -> Result<String> {
|
||||||
|
let bs = 1; // TODO
|
||||||
|
|
||||||
|
// patches and pixel_attention_mask
|
||||||
|
let (patches, nw_nh) = self.vision.process_one(image)?;
|
||||||
|
let dims = patches.dims();
|
||||||
|
let pixel_attention_mask = X::ones(&[dims[0], dims[1], dims[3], dims[4]]);
|
||||||
|
|
||||||
|
// input ids
|
||||||
|
let prompt = self.image_prompt_string(nw_nh, text);
|
||||||
|
let mut input_ids: Vec<f32> = self.text_embed.processor().encode_text_ids(&prompt, true)?;
|
||||||
|
|
||||||
|
// position ids
|
||||||
|
let mut position_ids = X::from(
|
||||||
|
(1..input_ids.len() + 1)
|
||||||
|
.map(|x| x as f32)
|
||||||
|
.collect::<Vec<f32>>(),
|
||||||
|
)
|
||||||
|
.insert_axis(0)?;
|
||||||
|
|
||||||
|
// past key_values
|
||||||
|
let mut past_key_values = vec![
|
||||||
|
X::zeros(&[bs, self.num_key_value_heads, 0, self.head_dim]);
|
||||||
|
self.num_hidden_layers * 2
|
||||||
|
];
|
||||||
|
|
||||||
|
// generate
|
||||||
|
let logits_sampler = LogitsSampler::new();
|
||||||
|
let mut token_ids: Vec<u32> = vec![];
|
||||||
|
for ii in 0..self.max_length {
|
||||||
|
// inputs embeds
|
||||||
|
let input_ids_x = X::from(input_ids.clone()).insert_axis(0)?;
|
||||||
|
let mut inputs_embeds =
|
||||||
|
self.text_embed.inference(input_ids_x.clone().into())?[0].clone();
|
||||||
|
|
||||||
|
// encode image and merge
|
||||||
|
if ii == 0 {
|
||||||
|
let image_features = self.vision.inference(Xs::from(vec![
|
||||||
|
patches.clone(),
|
||||||
|
pixel_attention_mask.clone(),
|
||||||
|
]))?[0]
|
||||||
|
.clone();
|
||||||
|
let dims = image_features.dim();
|
||||||
|
let image_features = image_features.to_shape((dims[0] * dims[1], dims[2]))?;
|
||||||
|
|
||||||
|
// merge
|
||||||
|
let mut r = 0;
|
||||||
|
for (i, &token_id) in input_ids_x.indexed_iter() {
|
||||||
|
if token_id == self.image_token_id as f32 {
|
||||||
|
inputs_embeds
|
||||||
|
.0
|
||||||
|
.slice_mut(s![0, i[1], ..])
|
||||||
|
.assign(&image_features.slice(s![r, ..]));
|
||||||
|
r += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inputs
|
||||||
|
let mut xs = vec![
|
||||||
|
inputs_embeds.clone(),
|
||||||
|
X::ones_like(&input_ids_x),
|
||||||
|
position_ids.clone(),
|
||||||
|
];
|
||||||
|
for i in 0..self.num_hidden_layers {
|
||||||
|
xs.push(past_key_values[i * 2].clone());
|
||||||
|
xs.push(past_key_values[i * 2 + 1].clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode
|
||||||
|
let decoder_outputs = self.decoder.inference(xs.into())?;
|
||||||
|
let logits = &decoder_outputs[0];
|
||||||
|
past_key_values = (1..decoder_outputs.len())
|
||||||
|
.step_by(2)
|
||||||
|
.flat_map(|i| [i, i + 1])
|
||||||
|
.map(|i| decoder_outputs[i].clone())
|
||||||
|
.collect();
|
||||||
|
let token_id = logits_sampler.decode(
|
||||||
|
&logits
|
||||||
|
.slice(s![0, -1, ..])
|
||||||
|
.into_owned()
|
||||||
|
.into_raw_vec_and_offset()
|
||||||
|
.0,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// early return
|
||||||
|
if token_id == self.eos_token_id {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
token_ids.push(token_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// update
|
||||||
|
input_ids = vec![token_id as f32];
|
||||||
|
position_ids = X::from(
|
||||||
|
position_ids
|
||||||
|
.slice(s![.., -1..])
|
||||||
|
.mapv(|x| x + 1.0)
|
||||||
|
.into_owned()
|
||||||
|
.into_dyn(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode tokens
|
||||||
|
let text = self
|
||||||
|
.text_embed
|
||||||
|
.processor()
|
||||||
|
.decode_tokens(&token_ids, true)?;
|
||||||
|
|
||||||
|
Ok(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn image_prompt_string(&self, nw_nh: (u32, u32), text: &str) -> String {
|
||||||
|
let (nw, nh) = nw_nh;
|
||||||
|
let image_tokens = self.image_token.repeat(self.image_seq_len);
|
||||||
|
let s1 = format!("{}User:", self.bos_token);
|
||||||
|
let s_global = format!(
|
||||||
|
"{}{}{}{}{}{}\nAssistant:",
|
||||||
|
self.fake_image_token,
|
||||||
|
self.global_img_token,
|
||||||
|
image_tokens,
|
||||||
|
self.fake_image_token,
|
||||||
|
text,
|
||||||
|
self.eos_token
|
||||||
|
);
|
||||||
|
|
||||||
|
match nw_nh {
|
||||||
|
(1, 1) => format!("{}{}", s1, s_global),
|
||||||
|
_ => {
|
||||||
|
let mut s = String::with_capacity(
|
||||||
|
s1.len()
|
||||||
|
+ (nw as usize
|
||||||
|
* nh as usize
|
||||||
|
* (self.fake_image_token.len() + image_tokens.len() + 20))
|
||||||
|
+ 10,
|
||||||
|
);
|
||||||
|
s.push_str(&s1);
|
||||||
|
// let mut s = s1;
|
||||||
|
for r in 1..=nh {
|
||||||
|
for c in 1..=nw {
|
||||||
|
s.push_str(&format!(
|
||||||
|
"{}<row_{}_col_{}>{}",
|
||||||
|
self.fake_image_token, r, c, image_tokens
|
||||||
|
));
|
||||||
|
}
|
||||||
|
s.push('\n');
|
||||||
|
}
|
||||||
|
format!("{}\n{}", s, s_global)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vision encoder of SmolVLM: wraps the inference engine together with the
// image processor and the input geometry read from the engine in `new`.
#[derive(Debug, Builder)]
pub struct VisionEncoder {
    // Inference engine that runs the encoder model.
    engine: Engine,
    // Optimal size of axis 1 of the engine's first input.
    num_patch: usize,
    // Optimal batch size reported by the engine.
    batch: usize,
    // Optimal size of axis 4 of the engine's first input; used as patch width.
    width: usize,
    // Optimal size of axis 3 of the engine's first input; used as patch height.
    height: usize,
    // Image pre-processor configured with `width` x `height`.
    processor: Processor,
    // Cloned from the engine's `ts` field.
    ts: Ts,
}
|
||||||
|
|
||||||
|
impl VisionEncoder {
|
||||||
|
pub fn new(options: Options) -> Result<Self> {
|
||||||
|
let engine = options.to_engine()?;
|
||||||
|
let (batch, num_patch, height, width, ts) = (
|
||||||
|
engine.batch().opt(),
|
||||||
|
engine.inputs_minoptmax()[0][1].opt(),
|
||||||
|
engine.inputs_minoptmax()[0][3].opt(),
|
||||||
|
engine.inputs_minoptmax()[0][4].opt(),
|
||||||
|
engine.ts.clone(),
|
||||||
|
);
|
||||||
|
let processor = options
|
||||||
|
.to_processor()?
|
||||||
|
.with_image_width(width as _)
|
||||||
|
.with_image_height(height as _);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
engine,
|
||||||
|
num_patch,
|
||||||
|
batch,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
processor,
|
||||||
|
ts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_patches(
|
||||||
|
image: &DynamicImage,
|
||||||
|
patch_size: (u32, u32),
|
||||||
|
) -> (Vec<DynamicImage>, (u32, u32)) {
|
||||||
|
let mut patches = vec![];
|
||||||
|
let image_rgb8 = image.to_rgb8();
|
||||||
|
let (image_width, image_height) = image_rgb8.dimensions();
|
||||||
|
let (patch_width, patch_height) = patch_size;
|
||||||
|
|
||||||
|
let (nw, nh) = if image_width > patch_width || image_height > patch_height {
|
||||||
|
// calculate the number of splits
|
||||||
|
let nw = image_width.div_ceil(patch_width);
|
||||||
|
let nh = image_height.div_ceil(patch_height);
|
||||||
|
|
||||||
|
// calculate the optimal width and height for the sub-images
|
||||||
|
let optimal_height = image_height.div_ceil(nh);
|
||||||
|
let optimal_width = image_width.div_ceil(nw);
|
||||||
|
|
||||||
|
// SubImage
|
||||||
|
for r in 0..nh {
|
||||||
|
for c in 0..nw {
|
||||||
|
let x0 = c * optimal_width;
|
||||||
|
let y0 = r * optimal_height;
|
||||||
|
let x1 = (x0 + optimal_width).min(image_width);
|
||||||
|
let y1 = (y0 + optimal_height).min(image_height);
|
||||||
|
let sub_image = image_rgb8.view(x0, y0, x1 - x0, y1 - y0).to_image();
|
||||||
|
patches.push(DynamicImage::from(sub_image));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(nw, nh)
|
||||||
|
} else {
|
||||||
|
(1, 1)
|
||||||
|
};
|
||||||
|
patches.push(image.clone());
|
||||||
|
|
||||||
|
(patches, (nw, nh))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs the encoder engine on `xs` and returns its outputs unchanged.
pub fn inference(&mut self, xs: Xs) -> Result<Xs> {
    let outputs = self.engine.run(xs)?;
    Ok(outputs)
}
|
||||||
|
|
||||||
|
pub fn process_one(&mut self, x: &DynamicImage) -> Result<(X, (u32, u32))> {
|
||||||
|
let (patches, nw_nh) = Self::create_patches(x, (self.width as _, self.height as _));
|
||||||
|
let patches = self.processor.process_images(&patches)?.insert_axis(0)?;
|
||||||
|
|
||||||
|
Ok((patches, (nw_nh)))
|
||||||
|
}
|
||||||
|
}
|
4
src/models/smolvlm/mod.rs
Normal file
4
src/models/smolvlm/mod.rs
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
mod config;
|
||||||
|
mod r#impl;
|
||||||
|
|
||||||
|
pub use r#impl::SmolVLM;
|
@ -2,6 +2,12 @@
|
|||||||
#[derive(aksr::Builder, Debug, Clone, Default, PartialEq)]
|
#[derive(aksr::Builder, Debug, Clone, Default, PartialEq)]
|
||||||
pub struct Text(String);
|
pub struct Text(String);
|
||||||
|
|
||||||
|
impl std::fmt::Display for Text {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
write!(f, "{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl std::ops::Deref for Text {
|
impl std::ops::Deref for Text {
|
||||||
type Target = String;
|
type Target = String;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user